mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
1 Commits
nit_error
...
additional
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
675aa517f5 |
@@ -1,6 +1,6 @@
|
||||
name: Run Playwright Tests
|
||||
name: Run Chromatic Tests
|
||||
concurrency:
|
||||
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
@@ -198,47 +198,43 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
chromatic-tests:
|
||||
name: Chromatic Tests
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
needs: playwright-tests
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
- name: Download Playwright test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
playwright: true
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: ./web
|
||||
env:
|
||||
CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
35
.github/workflows/pr-integration-tests.yml
vendored
35
.github/workflows/pr-integration-tests.yml
vendored
@@ -99,7 +99,7 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
@@ -108,13 +108,12 @@ jobs:
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
@@ -144,27 +143,24 @@ jobs:
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
@@ -194,24 +190,15 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
@@ -221,8 +208,6 @@ jobs:
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
@@ -244,13 +229,13 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
@@ -264,4 +249,4 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
2
.vscode/launch.template.jsonc
vendored
2
.vscode/launch.template.jsonc
vendored
@@ -205,7 +205,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
|
||||
@@ -28,11 +28,11 @@ RUN apt-get update && \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30 \
|
||||
libblkid1 \
|
||||
libmount1 \
|
||||
libsmartcols1 \
|
||||
libuuid1 \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc \
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Add checkpointing/failure handling
|
||||
|
||||
Revision ID: b7a7eee5aa15
|
||||
Revises: f39c5794c10a
|
||||
Create Date: 2025-01-24 15:17:36.763172
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7a7eee5aa15"
|
||||
down_revision = "f39c5794c10a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_index_attempt_cc_pair_settings_poll",
|
||||
"index_attempt",
|
||||
[
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
"status",
|
||||
sa.text("time_updated DESC"),
|
||||
],
|
||||
)
|
||||
|
||||
# Drop the old IndexAttemptError table
|
||||
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
|
||||
op.drop_table("index_attempt_errors")
|
||||
|
||||
# Create the new version of the table
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
sa.Column("document_link", sa.String(), nullable=True),
|
||||
sa.Column("entity_id", sa.String(), nullable=True),
|
||||
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("failure_message", sa.Text(), nullable=False),
|
||||
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("SET lock_timeout = '5s'")
|
||||
|
||||
# try a few times to drop the table, this has been observed to fail due to other locks
|
||||
# blocking the drop
|
||||
NUM_TRIES = 10
|
||||
for i in range(NUM_TRIES):
|
||||
try:
|
||||
op.drop_table("index_attempt_errors")
|
||||
break
|
||||
except Exception as e:
|
||||
if i == NUM_TRIES - 1:
|
||||
raise e
|
||||
print(f"Error dropping table: {e}. Retrying...")
|
||||
|
||||
op.execute("SET lock_timeout = DEFAULT")
|
||||
|
||||
# Recreate the old IndexAttemptError table
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
|
||||
sa.Column("batch", sa.Integer(), nullable=True),
|
||||
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("traceback", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"index_attempt_id",
|
||||
"index_attempt_errors",
|
||||
["time_created"],
|
||||
)
|
||||
|
||||
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
|
||||
op.drop_column("index_attempt", "checkpoint_pointer")
|
||||
op.drop_column("index_attempt", "poll_range_start")
|
||||
op.drop_column("index_attempt", "poll_range_end")
|
||||
@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -17,7 +17,7 @@ logger = setup_logger()
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="choose_tool",
|
||||
action=choose_tool,
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "call_tool"
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,14 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -87,11 +85,9 @@ def check_sub_answer(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
|
||||
fast_llm.invoke,
|
||||
response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
@@ -100,7 +96,7 @@ def check_sub_answer(
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
@@ -46,13 +47,11 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -111,14 +110,15 @@ def generate_sub_answer(
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: list[str] = []
|
||||
|
||||
def stream_sub_answer() -> list[str]:
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -142,15 +142,8 @@ def generate_sub_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
|
||||
stream_sub_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -59,15 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -80,7 +77,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
@@ -234,11 +230,7 @@ def generate_initial_answer(
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = trim_prompt_piece(
|
||||
@@ -268,16 +260,15 @@ def generate_initial_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_initial_answer() -> list[str]:
|
||||
response: list[str] = []
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -301,16 +292,9 @@ def generate_initial_answer(
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
streamed_tokens.append(content)
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
stream_initial_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -36,10 +36,7 @@ from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -50,7 +47,6 @@ from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -135,12 +131,10 @@ def decompose_orig_question(
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
@@ -160,7 +154,7 @@ def decompose_orig_question(
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError) as e:
|
||||
except LLMTimeoutError as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
"""
|
||||
LangGraph edge to route to agent search.
|
||||
"""
|
||||
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "call_tool"
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
# Choose the initial tool
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=choose_tool,
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
# Call the tool, if required
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
# Use the tool response
|
||||
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["call_tool", "start_agent_search", "logging_node"],
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
|
||||
@@ -33,15 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -107,14 +105,11 @@ def compare_answers(
|
||||
refined_answer_improvement: bool | None = None
|
||||
# no need to stream this
|
||||
try:
|
||||
resp = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
resp = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -44,10 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -56,7 +53,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -138,17 +134,15 @@ def create_refined_sub_questions(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -22,17 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -90,42 +84,30 @@ def extract_entities_terms(
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - extract entities terms")
|
||||
except ValueError:
|
||||
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - extract entities terms")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -65,21 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -98,7 +92,6 @@ from onyx.prompts.agent_search import (
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -260,12 +253,7 @@ def generate_validate_refined_answer(
|
||||
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
@@ -296,13 +284,13 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -327,15 +315,8 @@ def generate_validate_refined_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
return streamed_tokens
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
stream_refined_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
@@ -402,20 +383,16 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
validation_model = graph_config.tooling.fast_llm
|
||||
try:
|
||||
validation_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_response = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
separator=AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Timeout Error - validate refined answer")
|
||||
|
||||
|
||||
@@ -34,16 +34,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -71,7 +69,7 @@ def expand_queries(
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_num = 0, 0
|
||||
@@ -90,12 +88,10 @@ def expand_queries(
|
||||
rewritten_queries = []
|
||||
|
||||
try:
|
||||
llm_response_list = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
@@ -105,7 +101,7 @@ def expand_queries(
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -55,7 +55,6 @@ def rerank_documents(
|
||||
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
|
||||
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -63,31 +62,23 @@ def rerank_documents(
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Initial default: no reranking. Will be overwritten below if reranking is warranted
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
if not allow_agent_reranking:
|
||||
logger.info("Use of local rerank model without GPU, skipping reranking")
|
||||
# No reranking, stay with verified_documents as default
|
||||
|
||||
else:
|
||||
# Reranking is warranted, use the rerank_sections functon
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
|
||||
@@ -25,15 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
DOCUMENT_VERIFICATION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,11 +86,8 @@ def verify_documents(
|
||||
] # default is to treat document as relevant
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
response = fast_llm.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
@@ -101,7 +96,7 @@ def verify_documents(
|
||||
):
|
||||
verified_documents = []
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Timeout Error - verify documents")
|
||||
|
||||
@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def choose_tool(
|
||||
def llm_tool_choice(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
||||
write_custom_event("basic_response", packet, writer)
|
||||
|
||||
|
||||
def call_tool(
|
||||
def tool_call(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -43,9 +43,8 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -81,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -397,13 +395,11 @@ def summarize_history(
|
||||
)
|
||||
|
||||
try:
|
||||
history_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
llm.invoke,
|
||||
history_response = llm.invoke(
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
|
||||
@@ -94,7 +94,6 @@ from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -108,6 +107,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
|
||||
@@ -36,15 +36,6 @@ beat_task_templates.extend(
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-checkpoint-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
@@ -16,7 +15,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -28,13 +26,7 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import (
|
||||
get_index_attempts_with_old_checkpoints,
|
||||
)
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
@@ -42,7 +34,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -79,123 +70,6 @@ from shared_configs.configs import SENTRY_DSN
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
"""The different statuses the watchdog can finish with.
|
||||
|
||||
TODO: create broader success/failure/abort categories
|
||||
"""
|
||||
|
||||
UNDEFINED = "undefined"
|
||||
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
SPAWN_FAILED = "spawn_failed" # connector spawn failed
|
||||
|
||||
BLOCKED_BY_DELETION = "blocked_by_deletion"
|
||||
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
|
||||
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
|
||||
FENCE_READINESS_TIMEOUT = (
|
||||
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
|
||||
)
|
||||
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
|
||||
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
|
||||
INDEX_ATTEMPT_MISMATCH = (
|
||||
"index_attempt_mismatch" # expected index attempt metadata not found in db
|
||||
)
|
||||
|
||||
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
|
||||
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
|
||||
|
||||
# the watchdog received a termination signal
|
||||
TERMINATED_BY_SIGNAL = "terminated_by_signal"
|
||||
|
||||
# the watchdog terminated the task due to no activity
|
||||
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
|
||||
|
||||
OUT_OF_MEMORY = "out_of_memory"
|
||||
|
||||
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
|
||||
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
|
||||
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
|
||||
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
|
||||
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
|
||||
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
|
||||
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
|
||||
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
|
||||
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
|
||||
}
|
||||
|
||||
return _ENUM_TO_CODE[self]
|
||||
|
||||
@classmethod
|
||||
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
|
||||
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
|
||||
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
|
||||
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
|
||||
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
|
||||
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
|
||||
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
|
||||
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
|
||||
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
|
||||
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
|
||||
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
|
||||
}
|
||||
|
||||
if code in _CODE_TO_ENUM:
|
||||
return _CODE_TO_ENUM[code]
|
||||
|
||||
return IndexingWatchdogTerminalStatus.UNDEFINED
|
||||
|
||||
|
||||
class SimpleJobResult:
|
||||
"""The data we want to have when the watchdog finishes"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
|
||||
self.connector_source = None
|
||||
self.exit_code = None
|
||||
self.exception_str = None
|
||||
|
||||
status: IndexingWatchdogTerminalStatus
|
||||
connector_source: str | None
|
||||
exit_code: int | None
|
||||
exception_str: str | None
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str | None
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
|
||||
|
||||
class ConnectorIndexingLogBuilder:
|
||||
def __init__(self, ctx: ConnectorIndexingContext):
|
||||
self.ctx = ctx
|
||||
|
||||
def build(self, msg: str, **kwargs: Any) -> str:
|
||||
msg_final = (
|
||||
f"{msg}: "
|
||||
f"tenant_id={self.ctx.tenant_id} "
|
||||
f"attempt={self.ctx.index_attempt_id} "
|
||||
f"cc_pair={self.ctx.cc_pair_id} "
|
||||
f"search_settings={self.ctx.search_settings_id}"
|
||||
)
|
||||
|
||||
# Append extra keyword arguments in logfmt style
|
||||
if kwargs:
|
||||
extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
|
||||
msg_final = f"{msg_final} {extra_logfmt}"
|
||||
|
||||
return msg_final
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -622,6 +496,7 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
@@ -638,21 +513,19 @@ def connector_indexing_task(
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise SimpleJobException(
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
@@ -661,24 +534,19 @@ def connector_indexing_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise SimpleJobException(
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise SimpleJobException(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload # The payload must exist
|
||||
if not payload:
|
||||
raise SimpleJobException(
|
||||
"connector_indexing_task: payload invalid or not found",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -688,11 +556,10 @@ def connector_indexing_task(
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise SimpleJobException(
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
f"payload_index_attempt={payload.index_attempt_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -716,14 +583,7 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise SimpleJobException(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
|
||||
)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
@@ -732,10 +592,10 @@ def connector_indexing_task(
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
@@ -743,21 +603,16 @@ def connector_indexing_task(
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise SimpleJobException(
|
||||
f"cc_pair not found: cc_pair={cc_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise SimpleJobException(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise SimpleJobException(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
@@ -795,6 +650,20 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
@@ -809,49 +678,41 @@ def connector_indexing_task(
|
||||
return n_final_progress
|
||||
|
||||
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
result.connector_source = connector_source
|
||||
def connector_indexing_task_wrapper(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Just wraps connector_indexing_task so we can log any exceptions before
|
||||
re-raising it."""
|
||||
result: int | None = None
|
||||
|
||||
if job.process:
|
||||
result.exit_code = job.process.exitcode
|
||||
|
||||
if job.status != "error":
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
return result
|
||||
|
||||
ignore_exitcode = False
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
try:
|
||||
result = connector_indexing_task(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
else:
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
@@ -869,32 +730,12 @@ def connector_indexing_proxy_task(
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
|
||||
To work around this, we use pool=threads and proxy our work to a spawned task.
|
||||
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
result = SimpleJobResult()
|
||||
|
||||
ctx = ConnectorIndexingContext(
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
log_builder = ConnectorIndexingLogBuilder(ctx)
|
||||
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - starting",
|
||||
mp_start_method=str(multiprocessing.get_start_method()),
|
||||
)
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"mp_start_method={multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
@@ -903,7 +744,7 @@ def connector_indexing_proxy_task(
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
connector_indexing_task_wrapper,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
@@ -913,223 +754,139 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
if not job:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError("Index attempt not found")
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
result.connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
exit_code: int | None
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
exit_code = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task exceptioned"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
job.release()
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
log_builder.build("Indexing watchdog - termination signal detected")
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception looking up index attempt"
|
||||
)
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
except Exception:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
normalized_exception_str = "None"
|
||||
if result.exception_str:
|
||||
normalized_exception_str = result.exception_str.replace(
|
||||
"\n", "\\n"
|
||||
).replace('"', '\\"')
|
||||
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=result.status.value,
|
||||
exit_code=str(result.exit_code),
|
||||
exception=f'"{normalized_exception_str}"',
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
|
||||
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled"
|
||||
)
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
return
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
|
||||
for attempt in old_attempts:
|
||||
task_logger.info(
|
||||
f"Cleaning up checkpoint for index attempt {attempt.id}"
|
||||
)
|
||||
cleanup_checkpoint_task.apply_async(
|
||||
kwargs={
|
||||
"index_attempt_id": attempt.id,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during checkpoint cleanup")
|
||||
return None
|
||||
finally:
|
||||
if locked:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_checkpoint_cleanup - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
|
||||
bind=True,
|
||||
)
|
||||
def cleanup_checkpoint_task(
|
||||
self: Task, *, index_attempt_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
cleanup_checkpoint(db_session, index_attempt_id)
|
||||
return
|
||||
|
||||
@@ -240,8 +240,7 @@ def validate_indexing_fence(
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
@@ -105,7 +105,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
|
||||
80
backend/onyx/background/indexing/checkpointing.py
Normal file
80
backend/onyx/background/indexing/checkpointing.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Experimental functionality related to splitting up indexing
|
||||
into a series of checkpoints to better handle intermittent failures
|
||||
/ jobs being killed by cloud providers."""
|
||||
import datetime
|
||||
|
||||
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
|
||||
|
||||
def _2010_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _2020_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _default_end_time(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
) -> datetime.datetime:
|
||||
"""If year is before 2010, go to the beginning of 2010.
|
||||
If year is 2010-2020, go in 5 year increments.
|
||||
If year > 2020, then go in 180 day increments.
|
||||
|
||||
For connectors that don't support a `filter_by` and instead rely on `sort_by`
|
||||
for polling, then this will cause a massive duplication of fetches. For these
|
||||
connectors, you may want to override this function to return a more reasonable
|
||||
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
|
||||
last_successful_run = (
|
||||
datetime_to_utc(last_successful_run) if last_successful_run else None
|
||||
)
|
||||
if last_successful_run is None or last_successful_run < _2010_dt():
|
||||
return _2010_dt()
|
||||
|
||||
if last_successful_run < _2020_dt():
|
||||
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
|
||||
|
||||
return last_successful_run + datetime.timedelta(days=180)
|
||||
|
||||
|
||||
def find_end_time_for_indexing_attempt(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
# source_type can be used to override the default for certain connectors, currently unused
|
||||
source_type: DocumentSource,
|
||||
) -> datetime.datetime | None:
|
||||
"""Is the current time unless the connector is run over a large period, in which case it is
|
||||
split up into large time segments that become smaller as it approaches the present
|
||||
"""
|
||||
# NOTE: source_type can be used to override the default for certain connectors
|
||||
end_of_window = _default_end_time(last_successful_run)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
if end_of_window < now:
|
||||
return end_of_window
|
||||
|
||||
# None signals that we should index up to current time
|
||||
return None
|
||||
|
||||
|
||||
def get_time_windows_for_index_attempt(
|
||||
last_successful_run: datetime.datetime, source_type: DocumentSource
|
||||
) -> list[tuple[datetime.datetime, datetime.datetime]]:
|
||||
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
|
||||
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
|
||||
|
||||
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
|
||||
start_of_window: datetime.datetime | None = last_successful_run
|
||||
while start_of_window:
|
||||
end_of_window = find_end_time_for_indexing_attempt(
|
||||
last_successful_run=start_of_window, source_type=source_type
|
||||
)
|
||||
time_windows.append(
|
||||
(
|
||||
start_of_window,
|
||||
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
)
|
||||
)
|
||||
start_of_window = end_of_window
|
||||
|
||||
return time_windows
|
||||
@@ -1,200 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from io import BytesIO
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.object_size_check import deep_getsizeof
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
|
||||
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
|
||||
|
||||
|
||||
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
|
||||
return f"checkpoint_{index_attempt_id}.json"
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
|
||||
) -> str:
|
||||
"""Save a checkpoint for a given index attempt to the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=checkpoint_pointer,
|
||||
content=BytesIO(checkpoint.model_dump_json().encode()),
|
||||
display_name=checkpoint_pointer,
|
||||
file_origin=FileOrigin.INDEXING_CHECKPOINT,
|
||||
file_type="application/json",
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
index_attempt.checkpoint_pointer = checkpoint_pointer
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
return checkpoint_pointer
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> ConnectorCheckpoint | None:
|
||||
"""Load a checkpoint for a given index attempt from the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
file_store = get_default_file_store(db_session)
|
||||
try:
|
||||
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
|
||||
checkpoint_data = checkpoint_io.read().decode("utf-8")
|
||||
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
def get_latest_valid_checkpoint(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> ConnectorCheckpoint:
|
||||
"""Get the latest valid checkpoint for a given connector credential pair"""
|
||||
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
db_session=db_session,
|
||||
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
|
||||
)
|
||||
checkpoint_candidates = [
|
||||
candidate
|
||||
for candidate in checkpoint_candidates
|
||||
if (
|
||||
candidate.poll_range_start == window_start
|
||||
and candidate.poll_range_end == window_end
|
||||
and candidate.status == IndexingStatus.FAILED
|
||||
and candidate.checkpoint_pointer is not None
|
||||
# we want to make sure that the checkpoint is actually useful
|
||||
# if it's only gone through a few docs, it's probably not worth
|
||||
# using. This also avoids weird cases where a connector is basically
|
||||
# non-functional but still "makes progress" by slowly moving the
|
||||
# checkpoint forward run after run
|
||||
and candidate.total_docs_indexed
|
||||
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
|
||||
)
|
||||
]
|
||||
|
||||
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
|
||||
# for now, capped at 10
|
||||
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
|
||||
logger.warning(
|
||||
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
|
||||
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
|
||||
"from scratch."
|
||||
)
|
||||
return ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
|
||||
# assumes latest checkpoint is the furthest along. This only isn't true
|
||||
# if something else has gone wrong.
|
||||
latest_valid_checkpoint_candidate = (
|
||||
checkpoint_candidates[0] if checkpoint_candidates else None
|
||||
)
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
if latest_valid_checkpoint_candidate:
|
||||
try:
|
||||
previous_checkpoint = load_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to load checkpoint from previous failed attempt with ID "
|
||||
f"{latest_valid_checkpoint_candidate.id}."
|
||||
)
|
||||
previous_checkpoint = None
|
||||
|
||||
if previous_checkpoint is not None:
|
||||
logger.info(
|
||||
f"Using checkpoint from previous failed attempt with ID "
|
||||
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
|
||||
f"{previous_checkpoint}"
|
||||
)
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
checkpoint=previous_checkpoint,
|
||||
)
|
||||
checkpoint = previous_checkpoint
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_index_attempts_with_old_checkpoints(
|
||||
db_session: Session, days_to_keep: int = 7
|
||||
) -> list[IndexAttempt]:
|
||||
"""Get all index attempts with checkpoints older than the specified number of days.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
days_to_keep: Number of days to keep checkpoints for (default: 7)
|
||||
|
||||
Returns:
|
||||
Number of checkpoints deleted
|
||||
"""
|
||||
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
|
||||
|
||||
# Find all index attempts with checkpoints older than cutoff_date
|
||||
old_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
and_(
|
||||
IndexAttempt.checkpoint_pointer.isnot(None),
|
||||
IndexAttempt.time_created < cutoff_date,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return old_attempts
|
||||
|
||||
|
||||
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if not index_attempt.checkpoint_pointer:
|
||||
return None
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.delete_file(index_attempt.checkpoint_pointer)
|
||||
|
||||
index_attempt.checkpoint_pointer = None
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
|
||||
"""Check if the checkpoint content size exceeds the limit (200MB)"""
|
||||
content_size = deep_getsizeof(checkpoint.checkpoint_content)
|
||||
if content_size > 200_000_000: # 200MB in bytes
|
||||
raise ValueError(
|
||||
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
|
||||
)
|
||||
@@ -5,8 +5,6 @@ not follow the expected behavior, etc.
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.context import SpawnProcess
|
||||
@@ -20,16 +18,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SimpleJobException(Exception):
|
||||
"""lets us raise an exception that will return a specific error code"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
code: int | None = kwargs.pop("code", None)
|
||||
self.code = code
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
JobStatusType = (
|
||||
Literal["error"]
|
||||
| Literal["finished"]
|
||||
@@ -40,10 +28,7 @@ JobStatusType = (
|
||||
|
||||
|
||||
def _initializer(
|
||||
func: Callable,
|
||||
queue: mp.Queue,
|
||||
args: list | tuple,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
|
||||
@@ -67,29 +52,13 @@ def _initializer(
|
||||
)
|
||||
|
||||
# Proceed with executing the target function
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except SimpleJobException as e:
|
||||
logger.exception("SimpleJob raised a SimpleJobException")
|
||||
error_msg = traceback.format_exc()
|
||||
queue.put(error_msg) # Send the exception to the parent process
|
||||
|
||||
sys.exit(e.code) # use the given exit code
|
||||
except Exception:
|
||||
logger.exception("SimpleJob raised an exception")
|
||||
error_msg = traceback.format_exc()
|
||||
queue.put(error_msg) # Send the exception to the parent process
|
||||
|
||||
sys.exit(255) # use 255 to indicate a generic exception
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _run_in_process(
|
||||
func: Callable,
|
||||
queue: mp.Queue,
|
||||
args: list | tuple,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
_initializer(func, queue, args, kwargs)
|
||||
_initializer(func, args, kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -98,8 +67,6 @@ class SimpleJob:
|
||||
|
||||
id: int
|
||||
process: Optional["SpawnProcess"] = None
|
||||
queue: Optional[mp.Queue] = None
|
||||
_exception: Optional[str] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@@ -133,15 +100,9 @@ class SimpleJob:
|
||||
def exception(self) -> str:
|
||||
"""Needed to match the Dask API, but not implemented since we don't currently
|
||||
have a way to get back the exception information from the child process."""
|
||||
|
||||
"""Retrieve exception from the multiprocessing queue if available."""
|
||||
if self._exception is None and self.queue and not self.queue.empty():
|
||||
self._exception = self.queue.get() # Get exception from queue
|
||||
|
||||
if self._exception:
|
||||
return self._exception
|
||||
|
||||
return f"Job with ID '{self.id}' did not report an exception."
|
||||
return (
|
||||
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
|
||||
)
|
||||
|
||||
|
||||
class SimpleJobClient:
|
||||
@@ -176,11 +137,8 @@ class SimpleJobClient:
|
||||
# this approach allows us to always "spawn" a new process regardless of
|
||||
# get_start_method's current setting
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
process = ctx.Process(
|
||||
target=_run_in_process, args=(func, queue, args), daemon=True
|
||||
)
|
||||
job = SimpleJob(id=job_id, process=process, queue=queue)
|
||||
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
self.jobs[job_id] = job
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import tracemalloc
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class MemoryTracer:
|
||||
def __init__(self, interval: int = 0, num_print_entries: int = 5):
|
||||
self.interval = interval
|
||||
self.num_print_entries = num_print_entries
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
self.counter = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the memory tracer if interval is greater than 0."""
|
||||
if self.interval > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={self.interval}")
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
self._take_snapshot()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the memory tracer if it's running."""
|
||||
if self.interval > 0:
|
||||
self.log_final_diff()
|
||||
tracemalloc.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
def _take_snapshot(self) -> None:
|
||||
"""Take a snapshot and update internal snapshot states."""
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__),
|
||||
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
|
||||
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def _log_diff(
|
||||
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
|
||||
) -> None:
|
||||
"""Log the memory difference between two snapshots."""
|
||||
stats = current.compare_to(previous, "traceback")
|
||||
for s in stats[: self.num_print_entries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def increment_and_maybe_trace(self) -> None:
|
||||
"""Increment counter and perform trace if interval is hit."""
|
||||
if self.interval <= 0:
|
||||
return
|
||||
|
||||
self.counter += 1
|
||||
if self.counter % self.interval == 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
|
||||
)
|
||||
self._take_snapshot()
|
||||
if self.snapshot and self.snapshot_prev:
|
||||
self._log_diff(self.snapshot, self.snapshot_prev)
|
||||
|
||||
def log_final_diff(self) -> None:
|
||||
"""Log the final memory diff between start and end of indexing."""
|
||||
if self.interval <= 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
|
||||
)
|
||||
self._take_snapshot()
|
||||
if self.snapshot and self.snapshot_first:
|
||||
self._log_diff(self.snapshot, self.snapshot_first)
|
||||
@@ -1,40 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
class IndexAttemptErrorPydantic(BaseModel):
|
||||
id: int
|
||||
connector_credential_pair_id: int
|
||||
|
||||
document_id: str | None
|
||||
document_link: str | None
|
||||
|
||||
entity_id: str | None
|
||||
failed_time_range_start: datetime | None
|
||||
failed_time_range_end: datetime | None
|
||||
|
||||
failure_message: str
|
||||
is_resolved: bool = False
|
||||
|
||||
time_created: datetime
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
id=model.id,
|
||||
connector_credential_pair_id=model.connector_credential_pair_id,
|
||||
document_id=model.document_id,
|
||||
document_link=model.document_link,
|
||||
entity_id=model.entity_id,
|
||||
failed_time_range_start=model.failed_time_range_start,
|
||||
failed_time_range_end=model.failed_time_range_end,
|
||||
failure_message=model.failure_message,
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -8,11 +7,8 @@ from datetime import timezone
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
|
||||
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import save_checkpoint
|
||||
from onyx.background.indexing.memory_tracer import MemoryTracer
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
@@ -21,8 +17,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -30,18 +24,15 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
@@ -62,7 +53,6 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
@@ -110,9 +100,7 @@ def _get_connector_runner(
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
connector=runnable_connector,
|
||||
batch_size=batch_size,
|
||||
time_range=(start_time, end_time),
|
||||
connector=runnable_connector, time_range=(start_time, end_time)
|
||||
)
|
||||
|
||||
|
||||
@@ -171,66 +159,6 @@ class RunIndexingContext(BaseModel):
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _check_connector_and_attempt_status(
|
||||
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
|
||||
) -> None:
|
||||
"""
|
||||
Checks the status of the connector credential pair and index attempt.
|
||||
Raises a RuntimeError if any conditions are not met.
|
||||
"""
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
|
||||
if (
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
|
||||
def _check_failure_threshold(
|
||||
total_failures: int,
|
||||
document_count: int,
|
||||
batch_num: int,
|
||||
last_failure: ConnectorFailure | None,
|
||||
) -> None:
|
||||
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
|
||||
|
||||
We consider the threshold hit if:
|
||||
1. We have more than 3 failures AND
|
||||
2. Failures account for more than 10% of processed documents
|
||||
"""
|
||||
failure_ratio = total_failures / (document_count or 1)
|
||||
|
||||
FAILURE_THRESHOLD = 3
|
||||
FAILURE_RATIO_THRESHOLD = 0.1
|
||||
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
|
||||
logger.error(
|
||||
f"Connector run failed with '{total_failures}' errors "
|
||||
f"after '{batch_num}' batches."
|
||||
)
|
||||
if last_failure and last_failure.exception:
|
||||
raise last_failure.exception from last_failure.exception
|
||||
|
||||
raise RuntimeError(
|
||||
f"Connector run encountered too many errors, aborting. "
|
||||
f"Last error: {last_failure}"
|
||||
)
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
@@ -241,8 +169,11 @@ def _run_indexing(
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
|
||||
TODO: do not change index attempt statuses here ... instead, set signals in redis
|
||||
and allow the monitor function to clean them up
|
||||
"""
|
||||
start_time = time.monotonic() # jsut used for logging
|
||||
start_time = time.time()
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
@@ -290,46 +221,6 @@ def _run_indexing(
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt_start.search_settings_id,
|
||||
db_session=db_session_temp,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# add start/end now that they have been set
|
||||
index_attempt_start.poll_range_start = window_start
|
||||
index_attempt_start.poll_range_end = window_end
|
||||
db_session_temp.add(index_attempt_start)
|
||||
db_session_temp.commit()
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
@@ -343,6 +234,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt_id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
@@ -354,73 +246,63 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
tracer: OnyxTracer
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = OnyxTracer()
|
||||
tracer.start()
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
total_failures = 0
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
run_end_dt = None
|
||||
tracer_counter: int
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
for ind, (window_start, window_end) in enumerate(
|
||||
get_time_windows_for_index_attempt(
|
||||
last_successful_run=datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
),
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
cc_pair_loop: ConnectorCredentialPair | None = None
|
||||
index_attempt_loop: IndexAttempt | None = None
|
||||
tracer_counter = 0
|
||||
|
||||
try:
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling.
|
||||
if index_attempt.from_beginning:
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_loop_start = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop_start:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
attempt=index_attempt_loop_start,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
unresolved_only=True,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
doc_id_to_unresolved_errors: dict[
|
||||
str, list[IndexAttemptError]
|
||||
] = defaultdict(list)
|
||||
for error in unresolved_errors:
|
||||
if error.document_id:
|
||||
doc_id_to_unresolved_errors[error.document_id].append(error)
|
||||
|
||||
entity_based_unresolved_errors = [
|
||||
error for error in unresolved_errors if error.entity_id
|
||||
]
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
for doc_batch in connector_runner.run():
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
@@ -431,37 +313,41 @@ def _run_indexing(
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
if (
|
||||
(
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# save the new checkpoint (if one is provided)
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# below is all document processing logic, so if no batch we can just continue
|
||||
if document_batch is None:
|
||||
continue
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
@@ -491,51 +377,15 @@ def _run_indexing(
|
||||
chunk_count += index_pipeline_result.total_chunks
|
||||
document_count += index_pipeline_result.total_docs
|
||||
|
||||
# resolve errors for documents that were successfully indexed
|
||||
failed_document_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in index_pipeline_result.failures
|
||||
if failure.failed_document
|
||||
]
|
||||
successful_document_ids = [
|
||||
document.id
|
||||
for document in document_batch
|
||||
if document.id not in failed_document_ids
|
||||
]
|
||||
for document_id in successful_document_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if document_id in doc_id_to_unresolved_errors:
|
||||
logger.info(
|
||||
f"Resolving IndexAttemptError for document '{document_id}'"
|
||||
)
|
||||
for error in doc_id_to_unresolved_errors[document_id]:
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
# add brand new failures
|
||||
if index_pipeline_result.failures:
|
||||
total_failures += len(index_pipeline_result.failures)
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
for failure in index_pipeline_result.failures:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures,
|
||||
document_count,
|
||||
batch_num,
|
||||
index_pipeline_result.failures[-1],
|
||||
)
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
# of the transactions when computing `NOW()`, so if we have
|
||||
# a long running transaction, the `time_updated` field will
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
|
||||
# so we need either to commit() or to use a new session
|
||||
update_docs_indexed(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
@@ -547,77 +397,126 @@ def _run_indexing(
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
|
||||
):
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
|
||||
# `make sure the checkpoints aren't getting too large`at some regular interval
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
run_end_dt = window_end
|
||||
if ctx.is_primary:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
# save latest checkpoint
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
f"{time.monotonic() - start_time} seconds"
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or (
|
||||
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
|
||||
)
|
||||
or (
|
||||
index_attempt_loop is not None
|
||||
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
|
||||
)
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
tracer.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
if (
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
memory_tracer.stop()
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# resolve entity-based errors
|
||||
for error in entity_based_unresolved_errors:
|
||||
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
if total_failures == 0:
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session_temp)
|
||||
|
||||
create_milestone_and_report(
|
||||
@@ -636,7 +535,7 @@ def _run_indexing(
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"failures={total_failures} "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
@@ -648,7 +547,7 @@ def _run_indexing(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=window_end,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
|
||||
@@ -659,43 +558,46 @@ def run_indexing_entrypoint(
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""Don't swallow exceptions here ... propagate them up."""
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
)
|
||||
|
||||
77
backend/onyx/background/indexing/tracer.py
Normal file
77
backend/onyx/background/indexing/tracer.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import tracemalloc
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class OnyxTracer:
|
||||
def __init__(self) -> None:
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
|
||||
def stop(self) -> None:
|
||||
tracemalloc.stop()
|
||||
|
||||
def snap(self) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap>"
|
||||
), # Exclude importlib
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap_external>"
|
||||
), # Exclude external importlib
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def log_snapshot(self, numEntries: int) -> None:
|
||||
if not self.snapshot:
|
||||
return
|
||||
|
||||
stats = self.snapshot.statistics("traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer snap: {s}")
|
||||
for line in s.traceback:
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
@staticmethod
|
||||
def log_diff(
|
||||
snap_current: tracemalloc.Snapshot,
|
||||
snap_previous: tracemalloc.Snapshot,
|
||||
numEntries: int,
|
||||
) -> None:
|
||||
stats = snap_current.compare_to(snap_previous, "traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def log_previous_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_prev:
|
||||
return
|
||||
|
||||
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
|
||||
|
||||
def log_first_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_first:
|
||||
return
|
||||
|
||||
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
|
||||
@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -82,26 +80,6 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
|
||||
using_cloud_reranking = (
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
|
||||
# TODO: this is a hack to force the query to be used for the search tool
|
||||
# this should be removed once we fully unify graph inputs (i.e.
|
||||
# remove SearchQuery entirely)
|
||||
if (
|
||||
force_use_tool.force_use
|
||||
and search_tool
|
||||
and force_use_tool.args
|
||||
and force_use_tool.tool_name == search_tool.name
|
||||
and QUERY_FIELD in force_use_tool.args
|
||||
):
|
||||
search_request.query = force_use_tool.args[QUERY_FIELD]
|
||||
|
||||
self.graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=prompt_builder,
|
||||
@@ -116,6 +94,7 @@ class Answer:
|
||||
force_use_tool=force_use_tool,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
assert db_session, "db_session must be provided for agentic persistence"
|
||||
self.graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -125,7 +104,6 @@ class Answer:
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
|
||||
@@ -31,9 +31,22 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
|
||||
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
|
||||
)
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
|
||||
@@ -165,172 +178,80 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
) # 2000
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
) # 3
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
|
||||
) # 30
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
) # 12
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
) # 6
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
|
||||
) # 1
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
|
||||
) # 4
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
) # 8
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -169,11 +169,6 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
|
||||
)
|
||||
|
||||
# defaults to False
|
||||
# generally should only be used for
|
||||
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
|
||||
|
||||
# defaults to False
|
||||
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
|
||||
|
||||
@@ -626,8 +621,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
|
||||
|
||||
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
|
||||
# Set to true to mock LLM responses for testing purposes
|
||||
|
||||
@@ -165,9 +165,6 @@ class DocumentSource(str, Enum):
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
@@ -246,7 +243,6 @@ class FileOrigin(str, Enum):
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@@ -278,7 +274,6 @@ class OnyxCeleryQueues:
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
LLM_MODEL_UPDATE = "llm_model_update"
|
||||
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
@@ -298,7 +293,6 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
|
||||
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_doc_permissions_sync_beat"
|
||||
)
|
||||
@@ -374,10 +368,6 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
|
||||
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
|
||||
return [(" ".join(combined) if combined else str(field_info), default_link)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [(str(item), default_link) for item in field_info]
|
||||
return [(item, default_link) for item in field_info]
|
||||
|
||||
return [(str(field_info), default_link)]
|
||||
|
||||
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, str | list[str]]]:
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
primary_field_value = (
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -20,139 +15,48 @@ logger = setup_logger()
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: ConnectorCheckpoint | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> CheckpointOutput:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner:
|
||||
"""
|
||||
Handles:
|
||||
- Batching
|
||||
- Additional exception logging
|
||||
- Combining different connector types to a single interface
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
batch_size: int,
|
||||
time_range: TimeRange | None = None,
|
||||
fail_loudly: bool = False,
|
||||
):
|
||||
self.connector = connector
|
||||
self.time_range = time_range
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.doc_batch: list[Document] = []
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
def run(
|
||||
self, checkpoint: ConnectorCheckpoint
|
||||
) -> Generator[
|
||||
tuple[
|
||||
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
self.doc_batch_generator = self.connector.poll_source(
|
||||
time_range[0].timestamp(), time_range[1].timestamp()
|
||||
)
|
||||
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
if time_range and fail_loudly:
|
||||
raise ValueError(
|
||||
"time_range specified, but passed in connector is not a PollConnector"
|
||||
)
|
||||
|
||||
self.doc_batch_generator = self.connector.load_from_state()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
if isinstance(self.connector, CheckpointConnector):
|
||||
if self.time_range is None:
|
||||
raise ValueError("time_range is required for CheckpointConnector")
|
||||
start = time.monotonic()
|
||||
for batch in self.doc_batch_generator:
|
||||
# to know how long connector is taking
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to build a batch."
|
||||
)
|
||||
|
||||
yield batch
|
||||
|
||||
start = time.monotonic()
|
||||
checkpoint_connector_generator = self.connector.load_from_checkpoint(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
next_checkpoint: ConnectorCheckpoint | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
self.doc_batch.append(document)
|
||||
|
||||
if failure is not None:
|
||||
yield None, failure, None
|
||||
|
||||
if len(self.doc_batch) >= self.batch_size:
|
||||
yield self.doc_batch, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
# yield remaining documents
|
||||
if len(self.doc_batch) > 0:
|
||||
yield self.doc_batch, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
yield None, None, next_checkpoint
|
||||
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
|
||||
)
|
||||
|
||||
else:
|
||||
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
finished_checkpoint.has_more = False
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if self.time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
for document_batch in self.connector.poll_source(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
):
|
||||
yield document_batch, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
for document_batch in self.connector.load_from_state():
|
||||
yield document_batch, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
@@ -172,6 +76,6 @@ class ConnectorRunner:
|
||||
)
|
||||
logger.error(
|
||||
f"Error in connector. type: {exc_type};\n"
|
||||
f"local_vars below -> \n{local_vars_str[:1024]}"
|
||||
f"local_vars below -> \n{local_vars_str}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -30,14 +30,12 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.linear.connector import LinearConnector
|
||||
from onyx.connectors.loopio.connector import LoopioConnector
|
||||
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from onyx.connectors.mock_connector.connector import MockConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
@@ -45,7 +43,7 @@ from onyx.connectors.productboard.connector import ProductboardConnector
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
@@ -68,8 +66,8 @@ def identify_connector_class(
|
||||
DocumentSource.WEB: WebConnector,
|
||||
DocumentSource.FILE: LocalFileConnector,
|
||||
DocumentSource.SLACK: {
|
||||
InputType.POLL: SlackConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
@@ -111,8 +109,6 @@ def identify_connector_class(
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
@@ -129,23 +125,10 @@ def identify_connector_class(
|
||||
|
||||
if any(
|
||||
[
|
||||
(
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector)
|
||||
),
|
||||
(
|
||||
input_type == InputType.POLL
|
||||
# either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointConnector)
|
||||
)
|
||||
),
|
||||
(
|
||||
input_type == InputType.EVENT
|
||||
and not issubclass(connector, EventConnector)
|
||||
),
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector),
|
||||
input_type == InputType.POLL and not issubclass(connector, PollConnector),
|
||||
input_type == InputType.EVENT and not issubclass(connector, EventConnector),
|
||||
]
|
||||
):
|
||||
raise ConnectorMissingException(
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -17,7 +14,6 @@ SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
@@ -109,33 +105,3 @@ class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def handle_event(self, event: Any) -> GenerateDocumentsOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CheckpointConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
Final return can be access via either:
|
||||
|
||||
```
|
||||
try:
|
||||
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value # Extracting the return value
|
||||
print(checkpoint)
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```
|
||||
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SingleConnectorYield(BaseModel):
|
||||
documents: list[Document]
|
||||
checkpoint: ConnectorCheckpoint
|
||||
failures: list[ConnectorFailure]
|
||||
unhandled_exception: str | None = None
|
||||
|
||||
|
||||
class MockConnector(CheckpointConnector):
|
||||
def __init__(
|
||||
self,
|
||||
mock_server_host: str,
|
||||
mock_server_port: int,
|
||||
) -> None:
|
||||
self.mock_server_host = mock_server_host
|
||||
self.mock_server_port = mock_server_port
|
||||
self.client = httpx.Client(timeout=30.0)
|
||||
|
||||
self.connector_yields: list[SingleConnectorYield] | None = None
|
||||
self.current_yield_index: int = 0
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
response = self.client.get(self._get_mock_server_url("get-documents"))
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
self.connector_yields = [
|
||||
SingleConnectorYield(**yield_data) for yield_data in data
|
||||
]
|
||||
return None
|
||||
|
||||
def _get_mock_server_url(self, endpoint: str) -> str:
|
||||
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
|
||||
|
||||
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
|
||||
response = self.client.post(
|
||||
self._get_mock_server_url("add-checkpoint"),
|
||||
json=checkpoint.model_dump(mode="json"),
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
if self.connector_yields is None:
|
||||
raise ValueError("No connector yields configured")
|
||||
|
||||
# Save the checkpoint to the mock server
|
||||
self._save_checkpoint(checkpoint)
|
||||
|
||||
yield_index = self.current_yield_index
|
||||
self.current_yield_index += 1
|
||||
current_yield = self.connector_yields[yield_index]
|
||||
|
||||
# If the current yield has an unhandled exception, raise it
|
||||
# This is used to simulate an unhandled failure in the connector.
|
||||
if current_yield.unhandled_exception:
|
||||
raise RuntimeError(current_yield.unhandled_exception)
|
||||
|
||||
# yield all documents
|
||||
for document in current_yield.documents:
|
||||
yield document
|
||||
|
||||
for failure in current_yield.failures:
|
||||
yield failure
|
||||
|
||||
return current_yield.checkpoint
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
@@ -188,48 +187,36 @@ class SlimDocument(BaseModel):
|
||||
perm_sync_data: Any | None = None
|
||||
|
||||
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
batch_num: int | None = None
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
|
||||
checkpoint_content: dict
|
||||
has_more: bool
|
||||
class DocumentErrorSummary(BaseModel):
|
||||
id: str
|
||||
semantic_id: str
|
||||
section_link: str | None
|
||||
|
||||
@classmethod
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
|
||||
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
|
||||
return cls(
|
||||
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
|
||||
return cls(
|
||||
id=str(data.get("id")),
|
||||
semantic_id=str(data.get("semantic_id")),
|
||||
section_link=str(data.get("section_link")),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"semantic_id": self.semantic_id,
|
||||
"section_link": self.section_link,
|
||||
}
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
document_id: str
|
||||
document_link: str | None = None
|
||||
|
||||
|
||||
class EntityFailure(BaseModel):
|
||||
entity_id: str
|
||||
missed_time_range: tuple[datetime, datetime] | None = None
|
||||
|
||||
|
||||
class ConnectorFailure(BaseModel):
|
||||
failed_document: DocumentFailure | None = None
|
||||
failed_entity: EntityFailure | None = None
|
||||
failure_message: str
|
||||
exception: Exception | None = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_failed_fields(cls, values: dict) -> dict:
|
||||
failed_document = values.get("failed_document")
|
||||
failed_entity = values.get("failed_entity")
|
||||
if (failed_document is None and failed_entity is None) or (
|
||||
failed_document is not None and failed_entity is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"Exactly one of 'failed_document' or 'failed_entity' must be specified."
|
||||
)
|
||||
return values
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
batch_num: int | None = None
|
||||
num_exceptions: int = 0
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
import contextvars
|
||||
import copy
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
@@ -18,18 +12,14 @@ from slack_sdk.errors import SlackApiError
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
@@ -43,8 +33,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
|
||||
ChannelType = dict[str, Any]
|
||||
MessageType = dict[str, Any]
|
||||
@@ -52,13 +40,6 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpointContent(TypedDict):
|
||||
channel_ids: list[str]
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
seen_thread_ts: list[str]
|
||||
|
||||
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
@@ -159,10 +140,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
|
||||
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
|
||||
|
||||
|
||||
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
|
||||
return f"{channel_id}__{thread_ts}"
|
||||
|
||||
|
||||
def thread_to_doc(
|
||||
channel: ChannelType,
|
||||
thread: ThreadType,
|
||||
@@ -205,7 +182,7 @@ def thread_to_doc(
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
id=f"{channel_id}__{thread[0]['ts']}",
|
||||
sections=[
|
||||
Section(
|
||||
link=get_message_link(event=m, client=client, channel_id=channel_id),
|
||||
@@ -290,97 +267,64 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get a channel by its ID.
|
||||
|
||||
Args:
|
||||
client: The Slack WebClient instance
|
||||
channel_id: The ID of the channel to fetch
|
||||
|
||||
Returns:
|
||||
The channel information
|
||||
|
||||
Raises:
|
||||
SlackApiError: If the channel cannot be fetched
|
||||
"""
|
||||
response = make_slack_api_call_w_retries(
|
||||
client.conversations_info,
|
||||
channel=channel_id,
|
||||
)
|
||||
return cast(ChannelType, response["channel"])
|
||||
|
||||
|
||||
def _get_messages(
|
||||
channel: ChannelType,
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
) -> tuple[list[MessageType], bool]:
|
||||
"""Slack goes from newest to oldest."""
|
||||
|
||||
# have to be in the channel in order to read messages
|
||||
if not channel["is_member"]:
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
response = make_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
limit=_SLACK_LIMIT,
|
||||
)
|
||||
response.validate()
|
||||
|
||||
messages = cast(list[MessageType], response.get("messages", []))
|
||||
|
||||
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
|
||||
"next_cursor", ""
|
||||
)
|
||||
has_more = bool(cursor)
|
||||
return messages, has_more
|
||||
|
||||
|
||||
def _message_to_doc(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> Document | None:
|
||||
filtered_thread: ThreadType | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
# skip threads we've already seen, since we've already processed all
|
||||
# messages in that thread
|
||||
if thread_ts in seen_thread_ts:
|
||||
return None
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
slack_cleaner = SlackTextCleaner(client=client)
|
||||
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
elif not msg_filter_func(message):
|
||||
filtered_thread = [message]
|
||||
# Cache to prevent refetching via API since users
|
||||
user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
|
||||
if filtered_thread:
|
||||
return thread_to_doc(
|
||||
channel=channel,
|
||||
thread=filtered_thread,
|
||||
slack_cleaner=slack_cleaner,
|
||||
client=client,
|
||||
user_cache=user_cache,
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
channel_docs = 0
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client, channel=channel, oldest=oldest, latest=latest
|
||||
)
|
||||
|
||||
return None
|
||||
seen_thread_ts: set[str] = set()
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
filtered_thread: ThreadType | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
# skip threads we've already seen, since we've already processed all
|
||||
# messages in that thread
|
||||
if thread_ts in seen_thread_ts:
|
||||
continue
|
||||
seen_thread_ts.add(thread_ts)
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
elif not msg_filter_func(message):
|
||||
filtered_thread = [message]
|
||||
|
||||
if filtered_thread:
|
||||
channel_docs += 1
|
||||
yield thread_to_doc(
|
||||
channel=channel,
|
||||
thread=filtered_thread,
|
||||
slack_cleaner=slack_cleaner,
|
||||
client=client,
|
||||
user_cache=user_cache,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
|
||||
)
|
||||
|
||||
|
||||
def _get_all_doc_ids(
|
||||
@@ -424,7 +368,7 @@ def _get_all_doc_ids(
|
||||
for message_ts in message_ts_set:
|
||||
channel_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
|
||||
id=f"{channel_id}__{message_ts}",
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
)
|
||||
@@ -432,51 +376,7 @@ def _get_all_doc_ids(
|
||||
yield channel_metadata_list
|
||||
|
||||
|
||||
def _process_message(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
|
||||
thread_ts = message.get("thread_ts")
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
# import random
|
||||
# if random.random() > 0.95:
|
||||
# raise RuntimeError("Random failure :P")
|
||||
|
||||
doc = _message_to_doc(
|
||||
message=message,
|
||||
client=client,
|
||||
channel=channel,
|
||||
slack_cleaner=slack_cleaner,
|
||||
user_cache=user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return (doc, thread_ts, None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return (
|
||||
None,
|
||||
thread_ts,
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
class SlackPollConnector(PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
@@ -490,14 +390,9 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
self.batch_size = batch_size
|
||||
self.client: WebClient | None = None
|
||||
|
||||
# just used for efficiency
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
self.client = WebClient(token=bot_token)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
@@ -516,155 +411,30 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Rough outline:
|
||||
|
||||
Step 1: Get all channels, yield back Checkpoint.
|
||||
Step 2: Loop through each channel. For each channel:
|
||||
Step 2.1: Get messages within the time range.
|
||||
Step 2.2: Process messages in parallel, yield back docs.
|
||||
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
|
||||
Slack returns messages from newest to oldest, so we need to keep track of
|
||||
the latest message we've seen in each channel.
|
||||
Step 2.4: If there are no more messages in the channel, switch the current
|
||||
channel to the next channel.
|
||||
"""
|
||||
if self.client is None or self.text_cleaner is None:
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
checkpoint_content = cast(
|
||||
SlackCheckpointContent,
|
||||
(
|
||||
copy.deepcopy(checkpoint.checkpoint_content)
|
||||
or {
|
||||
"channel_ids": None,
|
||||
"channel_completion_map": {},
|
||||
"current_channel": None,
|
||||
"seen_thread_ts": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
documents: list[Document] = []
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
|
||||
# throw an error if we use 0.0 on an account without infinite data
|
||||
# retention
|
||||
oldest=str(start) if start else None,
|
||||
latest=str(end),
|
||||
):
|
||||
documents.append(document)
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
documents = []
|
||||
|
||||
# if this is the very first time we've called this, need to
|
||||
# get all relevant channels and save them into the checkpoint
|
||||
if checkpoint_content["channel_ids"] is None:
|
||||
raw_channels = get_channels(self.client)
|
||||
filtered_channels = filter_channels(
|
||||
raw_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
if len(filtered_channels) == 0:
|
||||
return checkpoint
|
||||
|
||||
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
|
||||
checkpoint_content["current_channel"] = filtered_channels[0]
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=True,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
final_channel_ids = checkpoint_content["channel_ids"]
|
||||
channel = checkpoint_content["current_channel"]
|
||||
if channel is None:
|
||||
raise ValueError("current_channel key not found in checkpoint")
|
||||
|
||||
channel_id = channel["id"]
|
||||
if channel_id not in final_channel_ids:
|
||||
raise ValueError(f"Channel {channel_id} not found in checkpoint")
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
)
|
||||
message_batch, has_more_in_channel = _get_messages(
|
||||
channel, self.client, oldest, latest
|
||||
)
|
||||
new_latest = message_batch[-1]["ts"] if message_batch else latest
|
||||
|
||||
# Process messages in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future] = []
|
||||
for message in message_batch:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
futures.append(
|
||||
executor.submit(
|
||||
current_context.run,
|
||||
_process_message,
|
||||
message=message,
|
||||
client=self.client,
|
||||
channel=channel,
|
||||
slack_cleaner=self.text_cleaner,
|
||||
user_cache=self.user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
)
|
||||
)
|
||||
|
||||
for future in as_completed(futures):
|
||||
doc, thread_ts, failures = future.result()
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
# threaded. Multi-threaded _process_message reads from this
|
||||
# but since this is single threaded, we won't run into simul
|
||||
# writes. At worst, we can duplicate a thread, which will be
|
||||
# deduped later on.
|
||||
if thread_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
if thread_ts:
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failures:
|
||||
for failure in failures:
|
||||
yield failure
|
||||
|
||||
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
|
||||
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
|
||||
if has_more_in_channel:
|
||||
checkpoint_content["current_channel"] = channel
|
||||
else:
|
||||
new_channel_id = next(
|
||||
(
|
||||
channel_id
|
||||
for channel_id in final_channel_ids
|
||||
if channel_id
|
||||
not in checkpoint_content["channel_completion_map"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if new_channel_id:
|
||||
new_channel = _get_channel_by_id(self.client, new_channel_id)
|
||||
checkpoint_content["current_channel"] = new_channel
|
||||
else:
|
||||
checkpoint_content["current_channel"] = None
|
||||
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=checkpoint_content["current_channel"] is not None,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing channel {channel['name']}")
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=channel["id"],
|
||||
missed_time_range=(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
return checkpoint
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -672,7 +442,7 @@ if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
slack_channel = os.environ.get("SLACK_CHANNEL")
|
||||
connector = SlackConnector(
|
||||
connector = SlackPollConnector(
|
||||
channels=[slack_channel] if slack_channel else None,
|
||||
)
|
||||
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
|
||||
@@ -680,17 +450,6 @@ if __name__ == "__main__":
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
|
||||
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
|
||||
try:
|
||||
for document_or_failure in gen:
|
||||
if isinstance(document_or_failure, Document):
|
||||
print(document_or_failure)
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
print("Next checkpoint:", checkpoint)
|
||||
|
||||
print("Next checkpoint:", checkpoint)
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -34,14 +34,9 @@ def get_message_link(
|
||||
) -> str:
|
||||
channel_id = channel_id or event["channel"]
|
||||
message_ts = event["ts"]
|
||||
message_ts_without_dot = message_ts.replace(".", "")
|
||||
thread_ts = event.get("thread_ts")
|
||||
base_url = get_base_url(client.token)
|
||||
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
|
||||
f"?thread_ts={thread_ts}" if thread_ts else ""
|
||||
)
|
||||
return link
|
||||
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
|
||||
permalink = response["permalink"]
|
||||
return permalink
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from zulip import Client
|
||||
|
||||
@@ -41,39 +36,8 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.realm_name = realm_name
|
||||
|
||||
# Clean and normalize the URL
|
||||
realm_url = realm_url.strip().lower()
|
||||
|
||||
# Remove any trailing slashes
|
||||
realm_url = realm_url.rstrip("/")
|
||||
|
||||
# Ensure the URL has a scheme
|
||||
if not realm_url.startswith(("http://", "https://")):
|
||||
realm_url = f"https://{realm_url}"
|
||||
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(realm_url)
|
||||
|
||||
# Extract the base domain without any paths or ports
|
||||
netloc = parsed.netloc.split(":")[0] # Remove port if present
|
||||
|
||||
if not netloc:
|
||||
raise ValueError(
|
||||
f"Invalid realm URL format: {realm_url}. "
|
||||
f"URL must include a valid domain name."
|
||||
)
|
||||
|
||||
# Always use HTTPS for security
|
||||
self.base_url = f"https://{netloc}"
|
||||
self.client: Client | None = None
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to parse Zulip realm URL: {realm_url}. "
|
||||
f"Please provide a URL in the format: domain.com or https://domain.com. "
|
||||
f"Error: {str(e)}"
|
||||
)
|
||||
self.realm_url = realm_url if realm_url.endswith("/") else realm_url + "/"
|
||||
self.client: Client | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
contents = credentials["zuliprc_content"]
|
||||
@@ -91,17 +55,12 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
return None
|
||||
|
||||
def _message_to_narrow_link(self, m: Message) -> str:
|
||||
try:
|
||||
stream_name = m.display_recipient # assume str
|
||||
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
|
||||
topic_operand = encode_zulip_narrow_operand(m.subject)
|
||||
stream_name = m.display_recipient # assume str
|
||||
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
|
||||
topic_operand = encode_zulip_narrow_operand(m.subject)
|
||||
|
||||
narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
|
||||
return narrow_link
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Zulip message link: {e}")
|
||||
# Fallback to a basic link that at least includes the base URL
|
||||
return f"{self.base_url}#narrow/id/{m.id}"
|
||||
narrow_link = f"{self.realm_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
|
||||
return narrow_link
|
||||
|
||||
def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]:
|
||||
if self.client is None:
|
||||
@@ -124,40 +83,6 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
def _message_to_doc(self, message: Message) -> Document:
|
||||
text = f"{message.sender_full_name}: {message.content}"
|
||||
|
||||
try:
|
||||
# Convert timestamps to UTC datetime objects
|
||||
post_time = datetime.fromtimestamp(message.timestamp, tz=timezone.utc)
|
||||
edit_time = (
|
||||
datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc)
|
||||
if message.last_edit_timestamp is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Use the most recent edit time if available, otherwise use post time
|
||||
doc_time = edit_time if edit_time is not None else post_time
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse timestamp for message {message.id}: {e}")
|
||||
post_time = None
|
||||
edit_time = None
|
||||
doc_time = None
|
||||
|
||||
metadata: Dict[str, Union[str, List[str]]] = {
|
||||
"stream_name": str(message.display_recipient),
|
||||
"topic": str(message.subject),
|
||||
"sender_name": str(message.sender_full_name),
|
||||
"sender_email": str(message.sender_email),
|
||||
"message_timestamp": str(message.timestamp),
|
||||
"message_id": str(message.id),
|
||||
"stream_id": str(message.stream_id),
|
||||
"has_reactions": str(len(message.reactions) > 0),
|
||||
"content_type": str(message.content_type or "text"),
|
||||
}
|
||||
|
||||
# Always include edit timestamp in metadata when available
|
||||
if edit_time is not None:
|
||||
metadata["edit_timestamp"] = str(message.last_edit_timestamp)
|
||||
|
||||
return Document(
|
||||
id=f"{message.stream_id}__{message.id}",
|
||||
sections=[
|
||||
@@ -167,9 +92,8 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
],
|
||||
source=DocumentSource.ZULIP,
|
||||
semantic_identifier=f"{message.display_recipient} > {message.subject}",
|
||||
metadata=metadata,
|
||||
doc_updated_at=doc_time, # Use most recent edit time or post time
|
||||
semantic_identifier=message.display_recipient or message.subject,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
def _get_docs(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -20,7 +19,7 @@ class Message(BaseModel):
|
||||
sender_realm_str: str
|
||||
subject: str
|
||||
topic_links: Optional[List[Any]] = None
|
||||
last_edit_timestamp: Optional[int] = None
|
||||
last_edit_timestamp: Optional[int]
|
||||
edit_history: Any = None
|
||||
reactions: List[Any]
|
||||
submessages: List[Any]
|
||||
@@ -40,5 +39,5 @@ class GetMessagesResponse(BaseModel):
|
||||
found_oldest: Optional[bool] = None
|
||||
found_newest: Optional[bool] = None
|
||||
history_limited: Optional[bool] = None
|
||||
anchor: Optional[Union[str, int]] = None
|
||||
anchor: Optional[str] = None
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
|
||||
@@ -628,7 +628,7 @@ def create_new_chat_message(
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
refined_answer_improvement: bool | None = None,
|
||||
refined_answer_improvement: bool = True,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
|
||||
@@ -18,7 +18,6 @@ import boto3
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
@@ -40,7 +39,6 @@ from onyx.configs.app_configs import POSTGRES_PASSWORD
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
@@ -189,38 +187,20 @@ class SqlEngine:
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 20,
|
||||
"max_overflow": 5,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
|
||||
)
|
||||
|
||||
# Start with base kwargs that are valid for all pool types
|
||||
final_engine_kwargs: dict[str, Any] = {}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
# if null pool is specified, then we need to make sure that
|
||||
# we remove any passed in kwargs related to pool size that would
|
||||
# cause the initialization to fail
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
final_engine_kwargs["poolclass"] = pool.NullPool
|
||||
if "pool_size" in final_engine_kwargs:
|
||||
del final_engine_kwargs["pool_size"]
|
||||
if "max_overflow" in final_engine_kwargs:
|
||||
del final_engine_kwargs["max_overflow"]
|
||||
else:
|
||||
final_engine_kwargs["pool_size"] = 20
|
||||
final_engine_kwargs["max_overflow"] = 5
|
||||
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
|
||||
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
|
||||
|
||||
# any passed in kwargs override the defaults
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
engine = create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
@@ -319,21 +299,13 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
engine_kwargs["poolclass"] = pool.NullPool
|
||||
else:
|
||||
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
|
||||
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
**engine_kwargs,
|
||||
connect_args=connect_args,
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@@ -11,7 +11,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentErrorSummary
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
@@ -40,27 +41,6 @@ def get_last_attempt_for_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
def get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[IndexAttempt]:
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
IndexAttempt.status.notin_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> IndexAttempt | None:
|
||||
@@ -635,32 +615,23 @@ def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
|
||||
def create_index_attempt_error(
|
||||
index_attempt_id: int | None,
|
||||
connector_credential_pair_id: int,
|
||||
failure: ConnectorFailure,
|
||||
batch: int | None,
|
||||
docs: list[Document],
|
||||
exception_msg: str,
|
||||
exception_traceback: str,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
doc_summaries = []
|
||||
for doc in docs:
|
||||
doc_summary = DocumentErrorSummary.from_document(doc)
|
||||
doc_summaries.append(doc_summary.to_dict())
|
||||
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
document_id=(
|
||||
failure.failed_document.document_id if failure.failed_document else None
|
||||
),
|
||||
document_link=(
|
||||
failure.failed_document.document_link if failure.failed_document else None
|
||||
),
|
||||
entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None),
|
||||
failed_time_range_start=(
|
||||
failure.failed_entity.missed_time_range[0]
|
||||
if failure.failed_entity and failure.failed_entity.missed_time_range
|
||||
else None
|
||||
),
|
||||
failed_time_range_end=(
|
||||
failure.failed_entity.missed_time_range[1]
|
||||
if failure.failed_entity and failure.failed_entity.missed_time_range
|
||||
else None
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
batch=batch,
|
||||
doc_summaries=doc_summaries,
|
||||
error_msg=exception_msg,
|
||||
traceback=exception_traceback,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
@@ -678,42 +649,3 @@ def get_index_attempt_errors(
|
||||
|
||||
errors = db_session.scalars(stmt)
|
||||
return list(errors.all())
|
||||
|
||||
|
||||
def count_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
unresolved_only: bool,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(IndexAttemptError)
|
||||
.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
|
||||
)
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
result = db_session.scalar(stmt)
|
||||
return 0 if result is None else result
|
||||
|
||||
|
||||
def get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
unresolved_only: bool,
|
||||
db_session: Session,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[IndexAttemptError]:
|
||||
stmt = select(IndexAttemptError).where(
|
||||
IndexAttemptError.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
# Order by most recent first
|
||||
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
@@ -827,19 +827,6 @@ class IndexAttempt(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# for polling connectors, the start and end time of the poll window
|
||||
# will be set when the index attempt starts
|
||||
poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
|
||||
# Points to the last checkpoint that was saved for this run. The pointer here
|
||||
# can be taken to the FileStore to grab the actual checkpoint value
|
||||
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -883,13 +870,6 @@ class IndexAttempt(Base):
|
||||
desc("time_updated"),
|
||||
unique=False,
|
||||
),
|
||||
Index(
|
||||
"ix_index_attempt_cc_pair_settings_poll",
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
"status",
|
||||
desc("time_updated"),
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -906,33 +886,25 @@ class IndexAttempt(Base):
|
||||
|
||||
|
||||
class IndexAttemptError(Base):
|
||||
"""
|
||||
Represents an error that was encountered during an IndexAttempt.
|
||||
"""
|
||||
|
||||
__tablename__ = "index_attempt_errors"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
index_attempt_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("index_attempt.id"),
|
||||
nullable=False,
|
||||
)
|
||||
connector_credential_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=False,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
document_link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# The index of the batch where the error occurred (if looping thru batches)
|
||||
# Just informational.
|
||||
batch: Mapped[int | None] = mapped_column(Integer, default=None)
|
||||
doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
|
||||
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
traceback: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -941,6 +913,21 @@ class IndexAttemptError(Base):
|
||||
# This is the reverse side of the relationship
|
||||
index_attempt = relationship("IndexAttempt", back_populates="error_rows")
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"index_attempt_id",
|
||||
"time_created",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IndexAttempt(id={self.id!r}, "
|
||||
f"index_attempt_id={self.index_attempt_id!r}, "
|
||||
f"error_msg={self.error_msg!r})>"
|
||||
f"time_created={self.time_created!r}, "
|
||||
)
|
||||
|
||||
|
||||
class SyncRecord(Base):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
@@ -221,49 +217,3 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
deployment_name=search_settings.deployment_name,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
def embed_chunks_with_failure_handling(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
"""Tries to embed all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
"""
|
||||
|
||||
# First try to embed all chunks in one batch
|
||||
try:
|
||||
return embedder.embed_chunks(chunks=chunks), []
|
||||
except Exception:
|
||||
logger.exception("Failed to embed chunk batch. Trying individual docs.")
|
||||
# wait a couple seconds to let any rate limits or temporary issues resolve
|
||||
time.sleep(2)
|
||||
|
||||
# Try embedding each document's chunks individually
|
||||
chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_by_doc[chunk.source_document.id].append(chunk)
|
||||
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
|
||||
for doc_id, chunks_for_doc in chunks_by_doc.items():
|
||||
try:
|
||||
doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
|
||||
embedded_chunks.extend(doc_embedded_chunks)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return embedded_chunks, failures
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import get_access_for_documents
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
@@ -27,6 +29,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
@@ -38,12 +41,10 @@ from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
@@ -66,8 +67,6 @@ class IndexingPipelineResult(BaseModel):
|
||||
# number of chunks that were inserted into Vespa
|
||||
total_chunks: int
|
||||
|
||||
failures: list[ConnectorFailure]
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
@@ -157,10 +156,14 @@ def index_doc_batch_with_handler(
|
||||
document_index: DocumentIndex,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
attempt_id: int | None,
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
new_docs=0, total_docs=len(document_batch), total_chunks=0
|
||||
)
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
@@ -173,25 +176,47 @@ def index_doc_batch_with_handler(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to index document batch: {document_batch}")
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(document_batch),
|
||||
total_chunks=0,
|
||||
failures=[
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=document.id,
|
||||
document_link=(
|
||||
document.sections[0].link if document.sections else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
|
||||
logger.error(
|
||||
"NOTE: HTTP Status 507 Insufficient Storage indicates "
|
||||
"you need to allocate more memory or disk space to the "
|
||||
"Vespa/index container."
|
||||
)
|
||||
for document in document_batch
|
||||
],
|
||||
|
||||
if INDEXING_EXCEPTION_LIMIT == 0:
|
||||
raise
|
||||
|
||||
trace = traceback.format_exc()
|
||||
create_index_attempt_error(
|
||||
attempt_id,
|
||||
batch=index_attempt_metadata.batch_num,
|
||||
docs=document_batch,
|
||||
exception_msg=str(e),
|
||||
exception_traceback=trace,
|
||||
db_session=db_session,
|
||||
)
|
||||
logger.exception(
|
||||
f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'"
|
||||
)
|
||||
|
||||
index_attempt_metadata.num_exceptions += 1
|
||||
if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT:
|
||||
logger.warning(
|
||||
f"Maximum number of exceptions for this index attempt "
|
||||
f"({INDEXING_EXCEPTION_LIMIT}) has been reached. "
|
||||
f"The next exception will abort the indexing attempt."
|
||||
)
|
||||
elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT:
|
||||
logger.warning(
|
||||
f"Maximum number of exceptions for this index attempt "
|
||||
f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded."
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
return index_pipeline_result
|
||||
|
||||
@@ -351,12 +376,8 @@ def index_doc_batch(
|
||||
document_ids=[doc.id for doc in filtered_documents],
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
|
||||
)
|
||||
|
||||
doc_descriptors = [
|
||||
@@ -369,19 +390,10 @@ def index_doc_batch(
|
||||
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
# NOTE: no special handling for failures here, since the chunker is not
|
||||
# a common source of failure for the indexing pipeline
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
)
|
||||
if chunks
|
||||
else ([], [])
|
||||
)
|
||||
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
|
||||
|
||||
updatable_ids = [doc.id for doc in ctx.updatable_docs]
|
||||
|
||||
@@ -447,11 +459,7 @@ def index_doc_batch(
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
insertion_records = document_index.index(
|
||||
chunks=access_aware_chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
@@ -511,7 +519,6 @@ def index_doc_batch(
|
||||
new_docs=len([r for r in insertion_records if r.already_existed is False]),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(access_aware_chunks),
|
||||
failures=vector_db_write_failures + embedding_failures,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -524,6 +531,7 @@ def build_indexing_pipeline(
|
||||
db_session: Session,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
attempt_id: int | None = None,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
@@ -545,6 +553,7 @@ def build_indexing_pipeline(
|
||||
embedder=embedder,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
attempt_id=attempt_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@@ -57,13 +57,6 @@ class DocAwareChunk(BaseChunk):
|
||||
"""Used when logging the identity of a chunk"""
|
||||
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
|
||||
|
||||
def get_link(self) -> str | None:
|
||||
return (
|
||||
self.source_document.sections[0].link
|
||||
if self.source_document.sections
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
class IndexChunk(DocAwareChunk):
|
||||
embeddings: ChunkEmbedding
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _log_insufficient_storage_error(e: Exception) -> None:
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
|
||||
logger.error(
|
||||
"NOTE: HTTP Status 507 Insufficient Storage indicates "
|
||||
"you need to allocate more memory or disk space to the "
|
||||
"Vespa/index container."
|
||||
)
|
||||
|
||||
|
||||
def write_chunks_to_vector_db_with_backoff(
|
||||
document_index: DocumentIndex,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
|
||||
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
|
||||
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
|
||||
vector DB interface assumes that all chunks for a single document are present.
|
||||
"""
|
||||
|
||||
# first try to write the chunks to the vector db
|
||||
try:
|
||||
return (
|
||||
list(
|
||||
document_index.index(
|
||||
chunks=chunks,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
),
|
||||
[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to write chunk batch to vector db. Trying individual docs."
|
||||
)
|
||||
|
||||
# give some specific logging on this common failure case.
|
||||
_log_insufficient_storage_error(e)
|
||||
|
||||
# wait a couple seconds just to give the vector db a chance to recover
|
||||
time.sleep(2)
|
||||
|
||||
# try writing each doc one by one
|
||||
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_for_docs[chunk.source_document.id].append(chunk)
|
||||
|
||||
insertion_records: list[DocumentInsertionRecord] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
for doc_id, chunks_for_doc in chunks_for_docs.items():
|
||||
try:
|
||||
insertion_records.extend(
|
||||
document_index.index(
|
||||
chunks=chunks_for_doc,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to write document chunks for '{doc_id}' to vector db"
|
||||
)
|
||||
|
||||
# give some specific logging on this common failure case.
|
||||
_log_insufficient_storage_error(e)
|
||||
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return insertion_records, failures
|
||||
@@ -409,6 +409,10 @@ class DefaultMultiLLM(LLM):
|
||||
self._record_call(processed_prompt)
|
||||
|
||||
try:
|
||||
print(
|
||||
"model is",
|
||||
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
|
||||
)
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
# model choice
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.indexing import router as indexing_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
@@ -316,6 +317,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
|
||||
@@ -61,10 +61,10 @@ def _create_indexable_chunks(
|
||||
doc_updated_at=None,
|
||||
primary_owners=[],
|
||||
secondary_owners=[],
|
||||
chunk_count=preprocessed_doc["chunk_ind"] + 1,
|
||||
chunk_count=1,
|
||||
)
|
||||
|
||||
ids_to_documents[document.id] = document
|
||||
if preprocessed_doc["chunk_ind"] == 0:
|
||||
ids_to_documents[document.id] = document
|
||||
|
||||
chunk = DocMetadataAwareIndexChunk(
|
||||
chunk_id=preprocessed_doc["chunk_ind"],
|
||||
@@ -92,7 +92,6 @@ def _create_indexable_chunks(
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
return list(ids_to_documents.values()), chunks
|
||||
@@ -193,7 +192,6 @@ def seed_initial_documents(
|
||||
last_successful_index_time=last_index_time,
|
||||
seeding_flow=True,
|
||||
)
|
||||
|
||||
cc_pair_id = cast(int, result.data)
|
||||
processed_docs = fetch_versioned_implementation(
|
||||
"onyx.seeding.load_docs",
|
||||
@@ -251,5 +249,4 @@ def seed_initial_documents(
|
||||
.values(chunk_count=doc.chunk_count)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
@@ -40,9 +39,7 @@ from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import count_index_attempts_for_connector
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from onyx.db.models import SearchSettings
|
||||
@@ -549,47 +546,6 @@ def get_docs_sync_status(
|
||||
return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/errors")
|
||||
def get_cc_pair_indexing_errors(
|
||||
cc_pair_id: int,
|
||||
include_resolved: bool = Query(False),
|
||||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
|
||||
"""Gives back all errors for a given CC Pair. Allows pagination based on page and page_size params.
|
||||
|
||||
Args:
|
||||
cc_pair_id: ID of the connector-credential pair to get errors for
|
||||
include_resolved: Whether to include resolved errors in the results
|
||||
page: Page number for pagination, starting at 0
|
||||
page_size: Number of errors to return per page
|
||||
_: Current user, must be curator or admin
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Paginated list of indexing errors for the CC pair.
|
||||
"""
|
||||
total_count = count_index_attempt_errors_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
unresolved_only=not include_resolved,
|
||||
)
|
||||
|
||||
index_attempt_errors = get_index_attempt_errors_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
unresolved_only=not include_resolved,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedReturn(
|
||||
items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors],
|
||||
total_items=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
@@ -614,16 +613,6 @@ def get_connector_indexing_status(
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
indexing_statuses: list[ConnectorIndexingStatus] = []
|
||||
|
||||
if MOCK_CONNECTOR_FILE_PATH:
|
||||
import json
|
||||
|
||||
with open(MOCK_CONNECTOR_FILE_PATH, "r") as f:
|
||||
raw_data = json.load(f)
|
||||
connector_indexing_statuses = [
|
||||
ConnectorIndexingStatus(**status) for status in raw_data
|
||||
]
|
||||
return connector_indexing_statuses
|
||||
|
||||
# NOTE: If the connector is deleting behind the scenes,
|
||||
# accessing cc_pairs can be inconsistent and members like
|
||||
# connector or credential may be None.
|
||||
|
||||
23
backend/onyx/server/documents/indexing.py
Normal file
23
backend/onyx/server/documents/indexing.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.index_attempt import (
|
||||
get_index_attempt_errors,
|
||||
)
|
||||
from onyx.db.models import User
|
||||
from onyx.server.documents.models import IndexAttemptError
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@router.get("/admin/indexing-errors/{index_attempt_id}")
|
||||
def get_indexing_errors(
|
||||
index_attempt_id: int,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[IndexAttemptError]:
|
||||
indexing_errors = get_index_attempt_errors(index_attempt_id, db_session)
|
||||
return [IndexAttemptError.from_db_model(e) for e in indexing_errors]
|
||||
@@ -8,9 +8,9 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import DocumentErrorSummary
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
@@ -149,7 +150,6 @@ class CredentialSnapshot(CredentialBase):
|
||||
class IndexAttemptSnapshot(BaseModel):
|
||||
id: int
|
||||
status: IndexingStatus | None
|
||||
from_beginning: bool
|
||||
new_docs_indexed: int # only includes completely new docs
|
||||
total_docs_indexed: int # includes docs that are updated
|
||||
docs_removed_from_index: int
|
||||
@@ -166,7 +166,6 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
return IndexAttemptSnapshot(
|
||||
id=index_attempt.id,
|
||||
status=index_attempt.status,
|
||||
from_beginning=index_attempt.from_beginning,
|
||||
new_docs_indexed=index_attempt.new_docs_indexed or 0,
|
||||
total_docs_indexed=index_attempt.total_docs_indexed or 0,
|
||||
docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
|
||||
@@ -182,6 +181,31 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class IndexAttemptError(BaseModel):
|
||||
id: int
|
||||
index_attempt_id: int | None
|
||||
batch_number: int | None
|
||||
doc_summaries: list[DocumentErrorSummary]
|
||||
error_msg: str | None
|
||||
traceback: str | None
|
||||
time_created: str
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
|
||||
doc_summaries = [
|
||||
DocumentErrorSummary.from_dict(summary) for summary in error.doc_summaries
|
||||
]
|
||||
return IndexAttemptError(
|
||||
id=error.id,
|
||||
index_attempt_id=error.index_attempt_id,
|
||||
batch_number=error.batch,
|
||||
doc_summaries=doc_summaries,
|
||||
error_msg=error.error_msg,
|
||||
traceback=error.traceback,
|
||||
time_created=error.time_created.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
# These are the types currently supported by the pagination hook
|
||||
# More api endpoints can be refactored and be added here for use with the pagination hook
|
||||
PaginatedType = TypeVar(
|
||||
@@ -190,7 +214,6 @@ PaginatedType = TypeVar(
|
||||
FullUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
ChatSessionMinimal,
|
||||
IndexAttemptErrorPydantic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -213,6 +213,8 @@ def get_chat_session(
|
||||
# we need the tool call objs anyways, so just fetch them in a single call
|
||||
prefetch_tool_calls=True,
|
||||
)
|
||||
for message in session_messages:
|
||||
translate_db_message_to_chat_message_detail(message)
|
||||
|
||||
return ChatSessionDetailResponse(
|
||||
chat_session_id=session_id,
|
||||
|
||||
@@ -58,7 +58,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
|
||||
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
||||
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
class SearchResponseSummary(SearchQueryInfo):
|
||||
@@ -180,12 +179,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
QUERY_FIELD: {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "What to search for",
|
||||
},
|
||||
},
|
||||
"required": [QUERY_FIELD],
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -224,7 +223,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
rephrased_query = history_based_query_rephrase(
|
||||
query=query, history=history, llm=llm
|
||||
)
|
||||
return {QUERY_FIELD: rephrased_query}
|
||||
return {"query": rephrased_query}
|
||||
|
||||
"""Actual tool execution"""
|
||||
|
||||
@@ -280,7 +279,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
def run(
|
||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, llm_kwargs[QUERY_FIELD])
|
||||
query = cast(str, llm_kwargs["query"])
|
||||
force_no_rerank = False
|
||||
alternate_db_session = None
|
||||
retrieved_sections_callback = None
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import sys
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T", dict, list, tuple, set, frozenset)
|
||||
|
||||
|
||||
def deep_getsizeof(obj: T, seen: set[int] | None = None) -> int:
|
||||
"""Recursively sum size of objects, handling circular references."""
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return 0 # Prevent infinite recursion for circular references
|
||||
|
||||
seen.add(obj_id)
|
||||
size = sys.getsizeof(obj)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
size += sum(
|
||||
deep_getsizeof(k, seen) + deep_getsizeof(v, seen) for k, v in obj.items()
|
||||
)
|
||||
elif isinstance(obj, (list, tuple, set, frozenset)):
|
||||
size += sum(deep_getsizeof(i, seen) for i in obj)
|
||||
|
||||
return size
|
||||
@@ -1,4 +1,3 @@
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import as_completed
|
||||
@@ -14,10 +13,6 @@ logger = setup_logger()
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: list[tuple[Callable, tuple]],
|
||||
allow_failures: bool = False,
|
||||
@@ -83,10 +78,6 @@ class FunctionCall(Generic[R]):
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_functions_in_parallel(
|
||||
function_calls: list[FunctionCall],
|
||||
allow_failures: bool = False,
|
||||
@@ -118,49 +109,3 @@ def run_functions_in_parallel(
|
||||
raise
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
):
|
||||
super().__init__()
|
||||
self.timeout = timeout
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.exception: Exception | None = None
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self.result = self.func(*self.args, **self.kwargs)
|
||||
except Exception as e:
|
||||
self.exception = e
|
||||
|
||||
def end(self) -> None:
|
||||
raise TimeoutError(
|
||||
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
|
||||
)
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_with_timeout(
|
||||
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
"""
|
||||
Executes a function with a timeout. If the function doesn't complete within the specified
|
||||
timeout, raises TimeoutError.
|
||||
"""
|
||||
task = TimeoutThread(timeout, func, *args, **kwargs)
|
||||
task.start()
|
||||
task.join(timeout)
|
||||
|
||||
if task.exception is not None:
|
||||
raise task.exception
|
||||
if task.is_alive():
|
||||
task.end()
|
||||
|
||||
return task.result
|
||||
|
||||
@@ -42,7 +42,7 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
|
||||
@@ -33,7 +33,7 @@ stopasgroup=true
|
||||
command=celery -A onyx.background.celery.versioned_apps.light worker
|
||||
--loglevel=INFO
|
||||
--hostname=light@%%n
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
|
||||
stdout_logfile=/var/log/celery_worker_light.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def test_create_chat_session_and_send_messages() -> None:
|
||||
def test_create_chat_session_and_send_messages(db_session: Session) -> None:
|
||||
# Create a test user
|
||||
with get_session_context_manager() as db_session:
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
|
||||
base_url = "http://localhost:8080" # Adjust this to your API's base URL
|
||||
headers = {"Authorization": f"Bearer {test_user.id}"}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
ADMIN_USER_NAME = "admin_user"
|
||||
|
||||
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
|
||||
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
|
||||
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
|
||||
@@ -11,6 +9,3 @@ MAX_DELAY = 45
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
|
||||
NUM_DOCS = 5
|
||||
|
||||
MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost"
|
||||
MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001
|
||||
|
||||
@@ -223,13 +223,12 @@ class CCPairManager:
|
||||
@staticmethod
|
||||
def run_once(
|
||||
cc_pair: DATestCCPair,
|
||||
from_beginning: bool,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
body = {
|
||||
"connector_id": cc_pair.connector_id,
|
||||
"credential_ids": [cc_pair.credential_id],
|
||||
"from_beginning": from_beginning,
|
||||
"from_beginning": True,
|
||||
}
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
@@ -191,39 +186,3 @@ class DocumentManager:
|
||||
group_names,
|
||||
doc_creating_user,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fetch_documents_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
vespa_client: vespa_fixture,
|
||||
) -> list[SimpleTestDocument]:
|
||||
stmt = (
|
||||
select(DocumentByConnectorCredentialPair)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
)
|
||||
documents = db_session.execute(stmt).scalars().all()
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
doc_ids = [document.id for document in documents]
|
||||
retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
|
||||
|
||||
final_docs: list[SimpleTestDocument] = []
|
||||
# NOTE: they are really chunks, but we're assuming that for these tests
|
||||
# we only have one chunk per document for now
|
||||
for doc_dict in retrieved_docs_dict:
|
||||
doc_id = doc_dict["fields"]["document_id"]
|
||||
doc_content = doc_dict["fields"]["content"]
|
||||
final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
|
||||
|
||||
return final_docs
|
||||
|
||||
@@ -4,7 +4,6 @@ from urllib.parse import urlencode
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -14,7 +13,6 @@ from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestIndexAttempt
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -94,12 +92,8 @@ class IndexAttemptManager:
|
||||
"page_size": page_size,
|
||||
}
|
||||
|
||||
url = (
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
|
||||
f"?{urlencode(query_params, doseq=True)}"
|
||||
)
|
||||
response = requests.get(
|
||||
url=url,
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -110,125 +104,3 @@ class IndexAttemptManager:
|
||||
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
|
||||
total_items=data["total_items"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_latest_index_attempt_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> IndexAttemptSnapshot | None:
|
||||
"""Get an IndexAttempt by ID"""
|
||||
index_attempts = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id, user_performing_action=user_performing_action
|
||||
).items
|
||||
if not index_attempts:
|
||||
return None
|
||||
|
||||
index_attempts = sorted(
|
||||
index_attempts, key=lambda x: x.time_started or "0", reverse=True
|
||||
)
|
||||
return index_attempts[0]
|
||||
|
||||
@staticmethod
|
||||
def wait_for_index_attempt_start(
|
||||
cc_pair_id: int,
|
||||
index_attempts_to_ignore: list[int] | None = None,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> IndexAttemptSnapshot:
|
||||
"""Wait for an IndexAttempt to start"""
|
||||
start = datetime.now()
|
||||
index_attempts_to_ignore = index_attempts_to_ignore or []
|
||||
|
||||
while True:
|
||||
index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
if (
|
||||
index_attempt
|
||||
and index_attempt.time_started
|
||||
and index_attempt.id not in index_attempts_to_ignore
|
||||
):
|
||||
return index_attempt
|
||||
|
||||
elapsed = (datetime.now() - start).total_seconds()
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_index_attempt_by_id(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> IndexAttemptSnapshot:
|
||||
page_num = 0
|
||||
page_size = 10
|
||||
while True:
|
||||
page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair_id,
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
for attempt in page.items:
|
||||
if attempt.id == index_attempt_id:
|
||||
return attempt
|
||||
|
||||
if len(page.items) < page_size:
|
||||
break
|
||||
|
||||
page_num += 1
|
||||
|
||||
raise ValueError(f"IndexAttempt {index_attempt_id} not found")
|
||||
|
||||
@staticmethod
|
||||
def wait_for_index_attempt_completion(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
timeout: float = MAX_DELAY,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""Wait for an IndexAttempt to complete"""
|
||||
start = datetime.now()
|
||||
while True:
|
||||
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
|
||||
if index_attempt.status and index_attempt.status.is_terminal():
|
||||
print(f"IndexAttempt {index_attempt_id} completed")
|
||||
return
|
||||
|
||||
elapsed = (datetime.now() - start).total_seconds()
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for IndexAttempt {index_attempt_id} to complete. "
|
||||
f"elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
include_resolved: bool = True,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[IndexAttemptErrorPydantic]:
|
||||
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
|
||||
if include_resolved:
|
||||
url += "&include_resolved=true"
|
||||
response = requests.get(
|
||||
url=url,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [IndexAttemptErrorPydantic(**item) for item in data["items"]]
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.indexing.models import IndexingSetting
|
||||
from onyx.setup import setup_postgres
|
||||
from onyx.setup import setup_vespa
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.integration.common_utils.timeout import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -67,7 +66,6 @@ def _run_migrations(
|
||||
|
||||
def downgrade_postgres(
|
||||
database: str = "postgres",
|
||||
schema: str = "public",
|
||||
config_name: str = "alembic",
|
||||
revision: str = "base",
|
||||
clear_data: bool = False,
|
||||
@@ -75,8 +73,8 @@ def downgrade_postgres(
|
||||
"""Downgrade Postgres database to base state."""
|
||||
if clear_data:
|
||||
if revision != "base":
|
||||
raise ValueError("Clearing data without rolling back to base state")
|
||||
|
||||
logger.warning("Clearing data without rolling back to base state")
|
||||
# Delete all rows to allow migrations to be rolled back
|
||||
conn = psycopg2.connect(
|
||||
dbname=database,
|
||||
user=POSTGRES_USER,
|
||||
@@ -84,33 +82,38 @@ def downgrade_postgres(
|
||||
host=POSTGRES_HOST,
|
||||
port=POSTGRES_PORT,
|
||||
)
|
||||
conn.autocommit = True # Need autocommit for dropping schema
|
||||
cur = conn.cursor()
|
||||
|
||||
# Close any existing connections to the schema before dropping
|
||||
# Disable triggers to prevent foreign key constraints from being checked
|
||||
cur.execute("SET session_replication_role = 'replica';")
|
||||
|
||||
# Fetch all table names in the current database
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = '{database}'
|
||||
AND pg_stat_activity.state = 'idle in transaction'
|
||||
AND pid <> pg_backend_pid();
|
||||
"""
|
||||
SELECT tablename
|
||||
FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop and recreate the public schema - this removes ALL objects
|
||||
cur.execute(f"DROP SCHEMA {schema} CASCADE;")
|
||||
cur.execute(f"CREATE SCHEMA {schema};")
|
||||
tables = cur.fetchall()
|
||||
|
||||
# Restore default privileges
|
||||
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
|
||||
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")
|
||||
for table in tables:
|
||||
table_name = table[0]
|
||||
|
||||
# Don't touch migration history or Kombu
|
||||
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
|
||||
continue
|
||||
|
||||
cur.execute(f'DELETE FROM "{table_name}"')
|
||||
|
||||
# Re-enable triggers
|
||||
cur.execute("SET session_replication_role = 'origin';")
|
||||
|
||||
conn.commit()
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
return
|
||||
|
||||
# Downgrade to base
|
||||
conn_str = build_connection_string(
|
||||
db=database,
|
||||
@@ -154,37 +157,11 @@ def reset_postgres(
|
||||
setup_onyx: bool = True,
|
||||
) -> None:
|
||||
"""Reset the Postgres database."""
|
||||
# this seems to hang due to locking issues, so run with a timeout with a few retries
|
||||
NUM_TRIES = 10
|
||||
TIMEOUT = 10
|
||||
success = False
|
||||
for _ in range(NUM_TRIES):
|
||||
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
|
||||
try:
|
||||
run_with_timeout(
|
||||
downgrade_postgres,
|
||||
TIMEOUT,
|
||||
kwargs={
|
||||
"database": database,
|
||||
"config_name": config_name,
|
||||
"revision": "base",
|
||||
"clear_data": True,
|
||||
},
|
||||
)
|
||||
success = True
|
||||
break
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise RuntimeError("Postgres downgrade failed after 10 timeouts.")
|
||||
|
||||
logger.info("Upgrading Postgres...")
|
||||
downgrade_postgres(
|
||||
database=database, config_name=config_name, revision="base", clear_data=True
|
||||
)
|
||||
upgrade_postgres(database=database, config_name=config_name, revision="head")
|
||||
if setup_onyx:
|
||||
logger.info("Setting up Postgres...")
|
||||
with get_session_context_manager() as db_session:
|
||||
setup_postgres(db_session)
|
||||
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import Section
|
||||
|
||||
|
||||
def create_test_document(
|
||||
doc_id: str | None = None,
|
||||
text: str = "Test content",
|
||||
link: str = "http://test.com",
|
||||
source: DocumentSource = DocumentSource.MOCK_CONNECTOR,
|
||||
metadata: dict | None = None,
|
||||
) -> Document:
|
||||
"""Create a test document with the given parameters.
|
||||
|
||||
Args:
|
||||
doc_id: Optional document ID. If not provided, a random UUID will be generated.
|
||||
text: The text content of the document. Defaults to "Test content".
|
||||
link: The link for the document section. Defaults to "http://test.com".
|
||||
source: The document source. Defaults to MOCK_CONNECTOR.
|
||||
metadata: Optional metadata dictionary. Defaults to empty dict.
|
||||
"""
|
||||
doc_id = doc_id or f"test-doc-{uuid.uuid4()}"
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=[Section(text=text, link=link)],
|
||||
source=source,
|
||||
semantic_identifier=doc_id,
|
||||
doc_updated_at=datetime.now(timezone.utc),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
|
||||
def create_test_document_failure(
|
||||
doc_id: str,
|
||||
failure_message: str = "Simulated failure",
|
||||
document_link: str | None = None,
|
||||
) -> ConnectorFailure:
|
||||
"""Create a test document failure with the given parameters.
|
||||
|
||||
Args:
|
||||
doc_id: The ID of the document that failed.
|
||||
failure_message: The failure message. Defaults to "Simulated failure".
|
||||
document_link: Optional link to the failed document.
|
||||
"""
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=document_link,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
)
|
||||
@@ -1,18 +0,0 @@
|
||||
import multiprocessing
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
|
||||
# Use multiprocessing to prevent a thread from blocking the main thread
|
||||
with multiprocessing.Pool(processes=1) as pool:
|
||||
async_result = pool.apply_async(task, kwds=kwargs)
|
||||
try:
|
||||
# Wait at most timeout seconds for the function to complete
|
||||
result = async_result.get(timeout=timeout)
|
||||
return result
|
||||
except multiprocessing.TimeoutError:
|
||||
raise TimeoutError(f"Function timed out after {timeout} seconds")
|
||||
@@ -1,11 +1,12 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
@@ -35,24 +36,16 @@ def load_env_vars(env_file: str = ".env") -> None:
|
||||
load_env_vars()
|
||||
|
||||
|
||||
"""NOTE: for some reason using this seems to lead to misc
|
||||
`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly`
|
||||
errors.
|
||||
|
||||
Commenting out till we can get to the bottom of it. For now, just using
|
||||
instantiate the session directly within the test.
|
||||
"""
|
||||
# @pytest.fixture
|
||||
# def db_session() -> Generator[Session, None, None]:
|
||||
# with get_session_context_manager() as session:
|
||||
# yield session
|
||||
@pytest.fixture
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
with get_session_context_manager() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vespa_client() -> vespa_fixture:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return vespa_fixture(index_name=search_settings.index_name)
|
||||
def vespa_client(db_session: Session) -> vespa_fixture:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return vespa_fixture(index_name=search_settings.index_name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -63,27 +56,20 @@ def reset() -> None:
|
||||
@pytest.fixture
|
||||
def new_admin_user(reset: None) -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create(name=ADMIN_USER_NAME)
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser:
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
|
||||
|
||||
# if there are other users for some reason, reset and try again
|
||||
if not UserManager.is_role(user, UserRole.ADMIN):
|
||||
print("Trying to reset")
|
||||
reset_all()
|
||||
user = UserManager.create(name=ADMIN_USER_NAME)
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Failed to create admin user: {e}")
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
user = UserManager.login_as_user(
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
@@ -93,16 +79,10 @@ def admin_user() -> DATestUser:
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
if not UserManager.is_role(user, UserRole.ADMIN):
|
||||
reset_all()
|
||||
user = UserManager.create(name=ADMIN_USER_NAME)
|
||||
return user
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Failed to create or login as admin user: {e}")
|
||||
|
||||
raise RuntimeError("Failed to create or login as admin user")
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -138,9 +138,7 @@ def test_google_permission_sync(
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)
|
||||
|
||||
# run indexing
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair, after=before, user_performing_action=admin_user
|
||||
)
|
||||
@@ -186,9 +184,7 @@ def test_google_permission_sync(
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)
|
||||
|
||||
# Run indexing
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
||||
@@ -113,9 +113,7 @@ def test_slack_permission_sync(
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
@@ -307,9 +305,7 @@ def test_slack_group_permission_sync(
|
||||
)
|
||||
|
||||
# Run indexing
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
||||
@@ -111,9 +111,7 @@ def test_slack_prune(
|
||||
|
||||
# Run indexing
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
mock_connector_server:
|
||||
build:
|
||||
context: ./mock_connector_server
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8001:8001"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
networks:
|
||||
- onyx-stack_default
|
||||
networks:
|
||||
onyx-stack_default:
|
||||
name: onyx-stack_default
|
||||
external: true
|
||||
@@ -1,9 +0,0 @@
|
||||
FROM python:3.11.7-slim-bookworm
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN pip install fastapi uvicorn
|
||||
|
||||
COPY ./main.py /app/main.py
|
||||
|
||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]
|
||||
@@ -1,76 +0,0 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
# We would like to import these, but it makes building this so much harder/slower
|
||||
# from onyx.connectors.mock_connector.connector import SingleConnectorYield
|
||||
# from onyx.connectors.models import ConnectorCheckpoint
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Global state to store connector behavior configuration
|
||||
class ConnectorBehavior(BaseModel):
|
||||
connector_yields: list[dict] = Field(
|
||||
default_factory=list
|
||||
) # really list[SingleConnectorYield]
|
||||
called_with_checkpoints: list[dict] = Field(
|
||||
default_factory=list
|
||||
) # really list[ConnectorCheckpoint]
|
||||
|
||||
|
||||
current_behavior: ConnectorBehavior = ConnectorBehavior()
|
||||
|
||||
|
||||
@app.post("/set-behavior")
|
||||
async def set_behavior(behavior: list[dict]) -> None:
|
||||
"""Set the behavior for the next connector run"""
|
||||
global current_behavior
|
||||
current_behavior = ConnectorBehavior(connector_yields=behavior)
|
||||
|
||||
|
||||
@app.get("/get-documents")
|
||||
async def get_documents() -> list[dict]:
|
||||
"""Get the next batch of documents and update the checkpoint"""
|
||||
global current_behavior
|
||||
|
||||
if not current_behavior.connector_yields:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No documents or failures configured"
|
||||
)
|
||||
|
||||
connector_yields = current_behavior.connector_yields
|
||||
|
||||
# Clear the current behavior after returning it
|
||||
current_behavior = ConnectorBehavior()
|
||||
|
||||
return connector_yields
|
||||
|
||||
|
||||
@app.post("/add-checkpoint")
|
||||
async def add_checkpoint(checkpoint: dict) -> None:
|
||||
"""Add a checkpoint to the list of checkpoints. Called by the MockConnector."""
|
||||
global current_behavior
|
||||
current_behavior.called_with_checkpoints.append(checkpoint)
|
||||
|
||||
|
||||
@app.get("/get-checkpoints")
|
||||
async def get_checkpoints() -> list[dict]:
|
||||
"""Get the list of checkpoints. Used by the test to verify the
|
||||
proper checkpoint ordering."""
|
||||
global current_behavior
|
||||
return current_behavior.called_with_checkpoints
|
||||
|
||||
|
||||
@app.post("/reset")
|
||||
async def reset() -> None:
|
||||
"""Reset the connector behavior to default"""
|
||||
global current_behavior
|
||||
current_behavior = ConnectorBehavior()
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy"}
|
||||
@@ -9,8 +9,6 @@ from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
@@ -103,15 +101,10 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
|
||||
create_index_attempt_error(
|
||||
index_attempt_id=new_attempt.id,
|
||||
connector_credential_pair_id=cc_pair_1.id,
|
||||
failure=ConnectorFailure(
|
||||
failure_message="Test error",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=cc_pair_1.documents[0].id,
|
||||
document_link=None,
|
||||
),
|
||||
failed_entity=None,
|
||||
),
|
||||
batch=1,
|
||||
docs=[],
|
||||
exception_msg="",
|
||||
exception_traceback="",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -134,15 +127,10 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
)
|
||||
create_index_attempt_error(
|
||||
index_attempt_id=attempt_id,
|
||||
connector_credential_pair_id=cc_pair_1.id,
|
||||
failure=ConnectorFailure(
|
||||
failure_message="Test error",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=cc_pair_1.documents[0].id,
|
||||
document_link=None,
|
||||
),
|
||||
failed_entity=None,
|
||||
),
|
||||
batch=1,
|
||||
docs=[],
|
||||
exception_msg="",
|
||||
exception_traceback="",
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,518 +0,0 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.test_document_utils import create_test_document
|
||||
from tests.integration.common_utils.test_document_utils import (
|
||||
create_test_document_failure,
|
||||
)
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_client() -> httpx.Client:
|
||||
print(
|
||||
f"Initializing mock server client with host: "
|
||||
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
|
||||
f"{MOCK_CONNECTOR_SERVER_PORT}"
|
||||
)
|
||||
return httpx.Client(
|
||||
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
|
||||
def test_mock_connector_basic_flow(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the mock connector can successfully process documents and failures"""
|
||||
# Set up mock server behavior
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# create CC Pair + index attempt
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for index attempt to start
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for index attempt to finish
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify results
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
def test_mock_connector_with_failures(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the mock connector processes both successes and failures properly."""
|
||||
doc1 = create_test_document()
|
||||
doc2 = create_test_document()
|
||||
doc2_failure = create_test_document_failure(doc_id=doc2.id)
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [doc2_failure.model_dump(mode="json")],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create a CC Pair for the mock connector
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-failure-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for the index attempt to start and then complete
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
|
||||
# Verify results: doc1 should be indexed and doc2 should have an error entry
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == doc1.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 1
|
||||
error = errors[0]
|
||||
assert error.failure_message == doc2_failure.failure_message
|
||||
assert error.document_id == doc2.id
|
||||
|
||||
|
||||
def test_mock_connector_failure_recovery(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that a failed document can be successfully indexed in a subsequent attempt
|
||||
while maintaining previously successful documents."""
|
||||
# Create test documents and failure
|
||||
doc1 = create_test_document()
|
||||
doc2 = create_test_document()
|
||||
doc2_failure = create_test_document_failure(doc_id=doc2.id)
|
||||
entity_id = "test-entity-id"
|
||||
entity_failure_msg = "Simulated unhandled error"
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [doc1.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [
|
||||
doc2_failure.model_dump(mode="json"),
|
||||
ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=entity_id,
|
||||
missed_time_range=(
|
||||
datetime.now(timezone.utc) - timedelta(days=1),
|
||||
datetime.now(timezone.utc),
|
||||
),
|
||||
),
|
||||
failure_message=entity_failure_msg,
|
||||
).model_dump(mode="json"),
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create CC Pair and run initial indexing attempt
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for first index attempt to complete
|
||||
initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=initial_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=initial_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
|
||||
# Verify initial state: doc1 indexed, doc2 failed
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == doc1.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 2
|
||||
error_doc2 = next(error for error in errors if error.document_id == doc2.id)
|
||||
assert error_doc2.failure_message == doc2_failure.failure_message
|
||||
assert not error_doc2.is_resolved
|
||||
|
||||
error_entity = next(error for error in errors if error.entity_id == entity_id)
|
||||
assert error_entity.failure_message == entity_failure_msg
|
||||
assert not error_entity.is_resolved
|
||||
|
||||
# Update mock server to return success for both documents
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [
|
||||
doc1.model_dump(mode="json"),
|
||||
doc2.model_dump(mode="json"),
|
||||
],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Trigger another indexing attempt
|
||||
# NOTE: must be from beginning to handle the entity failure
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=True, user_performing_action=admin_user
|
||||
)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=recovery_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=recovery_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_second_index_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify both documents are now indexed
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 2
|
||||
document_ids = {doc.id for doc in documents}
|
||||
assert doc2.id in document_ids
|
||||
assert doc1.id in document_ids
|
||||
|
||||
# Verify original failures were marked as resolved
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 2
|
||||
error_doc2 = next(error for error in errors if error.document_id == doc2.id)
|
||||
error_entity = next(error for error in errors if error.entity_id == entity_id)
|
||||
|
||||
assert error_doc2.is_resolved
|
||||
assert error_entity.is_resolved
|
||||
|
||||
|
||||
def test_mock_connector_checkpoint_recovery(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that checkpointing works correctly when an unhandled exception occurs
|
||||
and that subsequent runs pick up from the last successful checkpoint."""
|
||||
# Create test documents
|
||||
# Create 100 docs for first batch, this is needed to get past the
|
||||
# `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
|
||||
docs_batch_1 = [create_test_document() for _ in range(100)]
|
||||
doc2 = create_test_document()
|
||||
doc3 = create_test_document()
|
||||
|
||||
# Set up mock server behavior for initial run:
|
||||
# - First yield: 100 docs with checkpoint1
|
||||
# - Second yield: doc2 with checkpoint2
|
||||
# - Third yield: unhandled exception
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [doc2.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=True
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
},
|
||||
{
|
||||
"documents": [],
|
||||
# should never hit this, unhandled exception happens first
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
"unhandled_exception": "Simulated unhandled error",
|
||||
},
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create CC Pair and run initial indexing attempt
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-checkpoint-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for first index attempt to complete
|
||||
initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=initial_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=initial_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 101 # 100 docs from first batch + doc2
|
||||
document_ids = {doc.id for doc in documents}
|
||||
assert doc2.id in document_ids
|
||||
assert all(doc.id in document_ids for doc in docs_batch_1)
|
||||
|
||||
# Get the checkpoints that were sent to the mock server
|
||||
response = mock_server_client.get("/get-checkpoints")
|
||||
assert response.status_code == 200
|
||||
initial_checkpoints = response.json()
|
||||
|
||||
# Verify we got the expected checkpoints in order
|
||||
assert len(initial_checkpoints) > 0
|
||||
assert (
|
||||
initial_checkpoints[0]["checkpoint_content"] == {}
|
||||
) # Initial empty checkpoint
|
||||
assert initial_checkpoints[1]["checkpoint_content"] == {}
|
||||
assert initial_checkpoints[2]["checkpoint_content"] == {}
|
||||
|
||||
# Reset the mock server for the next run
|
||||
response = mock_server_client.post("/reset")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Set up mock server behavior for recovery run - should succeed fully this time
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [doc3.model_dump(mode="json")],
|
||||
"checkpoint": ConnectorCheckpoint(
|
||||
checkpoint_content={}, has_more=False
|
||||
).model_dump(mode="json"),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Trigger another indexing attempt
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=recovery_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate status
|
||||
finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=recovery_index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_recovery_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify results
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 102 # 100 docs from first batch + doc2 + doc3
|
||||
document_ids = {doc.id for doc in documents}
|
||||
assert doc3.id in document_ids
|
||||
assert doc2.id in document_ids
|
||||
assert all(doc.id in document_ids for doc in docs_batch_1)
|
||||
|
||||
# Get the checkpoints from the recovery run
|
||||
response = mock_server_client.get("/get-checkpoints")
|
||||
assert response.status_code == 200
|
||||
recovery_checkpoints = response.json()
|
||||
|
||||
# Verify the recovery run started from the last successful checkpoint
|
||||
assert len(recovery_checkpoints) == 1
|
||||
assert recovery_checkpoints[0]["checkpoint_content"] == {}
|
||||
@@ -11,7 +11,6 @@ from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import ToolCallChunk
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.answer import Answer
|
||||
@@ -26,7 +25,6 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
@@ -37,7 +35,6 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
|
||||
from tests.unit.onyx.chat.conftest import QUERY
|
||||
|
||||
@@ -47,20 +44,6 @@ def answer_instance(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
) -> Answer:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
|
||||
|
||||
|
||||
def _answer_fixture_impl(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> Answer:
|
||||
return Answer(
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
@@ -81,13 +64,13 @@ def _answer_fixture_impl(
|
||||
llm=mock_llm,
|
||||
fast_llm=mock_llm,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
|
||||
search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings),
|
||||
search_request=SearchRequest(query=QUERY),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
|
||||
def test_basic_answer(answer_instance: Answer) -> None:
|
||||
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is a "),
|
||||
@@ -380,49 +363,3 @@ def test_is_cancelled(answer_instance: Answer) -> None:
|
||||
|
||||
# Verify LLM calls
|
||||
mock_llm.stream.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gpu_enabled,is_local_model",
|
||||
[
|
||||
(True, False),
|
||||
(False, True),
|
||||
(True, True),
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
def test_no_slow_reranking(
|
||||
gpu_enabled: bool,
|
||||
is_local_model: bool,
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=gpu_enabled,
|
||||
)
|
||||
rerank_settings = (
|
||||
None
|
||||
if is_local_model
|
||||
else RerankingDetails(
|
||||
rerank_model_name="test_model",
|
||||
rerank_api_url="test_url",
|
||||
rerank_api_key="test_key",
|
||||
num_rerank=10,
|
||||
rerank_provider_type=RerankerProvider.COHERE,
|
||||
)
|
||||
)
|
||||
answer_instance = _answer_fixture_impl(
|
||||
mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
|
||||
)
|
||||
|
||||
assert (
|
||||
answer_instance.graph_config.inputs.search_request.rerank_settings
|
||||
== rerank_settings
|
||||
)
|
||||
assert (
|
||||
answer_instance.graph_config.behavior.allow_agent_reranking == gpu_enabled
|
||||
or not is_local_model
|
||||
)
|
||||
|
||||
@@ -36,12 +36,7 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
mock_search_tool: SearchTool,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
question = config["question"]
|
||||
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
|
||||
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
def test_run_with_timeout_completes() -> None:
|
||||
"""Test that a function that completes within timeout works correctly"""
|
||||
|
||||
def quick_function(x: int) -> int:
|
||||
return x * 2
|
||||
|
||||
result = run_with_timeout(1.0, quick_function, x=21)
|
||||
assert result == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize("slow,timeout", [(1, 0.1), (0.3, 0.2)])
|
||||
def test_run_with_timeout_raises_on_timeout(slow: float, timeout: float) -> None:
|
||||
"""Test that a function that exceeds timeout raises TimeoutError"""
|
||||
|
||||
def slow_function() -> None:
|
||||
time.sleep(slow) # Sleep for 2 seconds
|
||||
|
||||
with pytest.raises(TimeoutError) as exc_info:
|
||||
start = time.time()
|
||||
run_with_timeout(timeout, slow_function) # Set timeout to 0.1 seconds
|
||||
end = time.time()
|
||||
assert end - start >= timeout
|
||||
assert end - start < (slow + timeout) / 2
|
||||
assert f"timed out after {timeout} seconds" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
|
||||
def test_run_with_timeout_propagates_exceptions() -> None:
|
||||
"""Test that other exceptions from the function are propagated properly"""
|
||||
|
||||
def error_function() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
run_with_timeout(1.0, error_function)
|
||||
|
||||
assert "Test error" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_run_with_timeout_with_args_and_kwargs() -> None:
|
||||
"""Test that args and kwargs are properly passed to the function"""
|
||||
|
||||
def complex_function(x: int, y: int, multiply: bool = False) -> int:
|
||||
if multiply:
|
||||
return x * y
|
||||
return x + y
|
||||
|
||||
# Test with just positional args
|
||||
result1 = run_with_timeout(1.0, complex_function, x=5, y=3)
|
||||
assert result1 == 8
|
||||
|
||||
# Test with positional and keyword args
|
||||
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
|
||||
assert result2 == 15
|
||||
@@ -61,7 +61,6 @@ services:
|
||||
# Other services
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
|
||||
@@ -98,9 +97,6 @@ services:
|
||||
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
|
||||
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
|
||||
|
||||
# Demo purposes
|
||||
- MOCK_CONNECTOR_FILE_PATH=${MOCK_CONNECTOR_FILE_PATH:-}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
|
||||
@@ -175,7 +171,6 @@ services:
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
|
||||
|
||||
@@ -23,7 +23,7 @@ export default defineConfig({
|
||||
viewport: { width: 1280, height: 720 },
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
testIgnore: ["**/codeUtils.test.ts", "**/chat/**/*.spec.ts"],
|
||||
testIgnore: ["**/codeUtils.test.ts"],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
import { Modal } from "@/components/Modal";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { IndexAttemptError } from "./types";
|
||||
import { localizeAndPrettify } from "@/lib/time";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useState } from "react";
|
||||
import { PageSelector } from "@/components/PageSelector";
|
||||
|
||||
interface IndexAttemptErrorsModalProps {
|
||||
errors: {
|
||||
items: IndexAttemptError[];
|
||||
total_items: number;
|
||||
};
|
||||
onClose: () => void;
|
||||
onResolveAll: () => void;
|
||||
isResolvingErrors?: boolean;
|
||||
onPageChange: (page: number) => void;
|
||||
currentPage: number;
|
||||
pageSize?: number;
|
||||
}
|
||||
|
||||
const DEFAULT_PAGE_SIZE = 10;
|
||||
|
||||
export default function IndexAttemptErrorsModal({
|
||||
errors,
|
||||
onClose,
|
||||
onResolveAll,
|
||||
isResolvingErrors = false,
|
||||
onPageChange,
|
||||
currentPage,
|
||||
pageSize = DEFAULT_PAGE_SIZE,
|
||||
}: IndexAttemptErrorsModalProps) {
|
||||
const totalPages = Math.ceil(errors.total_items / pageSize);
|
||||
const hasUnresolvedErrors = errors.items.some((error) => !error.is_resolved);
|
||||
|
||||
return (
|
||||
<Modal title="Indexing Errors" onOutsideClick={onClose} width="max-w-6xl">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-2">
|
||||
{isResolvingErrors ? (
|
||||
<div className="text-sm text-text-default">
|
||||
Currently attempting to resolve all errors by performing a full
|
||||
re-index. This may take some time to complete.
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="text-sm text-text-default">
|
||||
Below are the errors encountered during indexing. Each row
|
||||
represents a failed document or entity.
|
||||
</div>
|
||||
<div className="text-sm text-text-default">
|
||||
Click the button below to kick off a full re-index to try and
|
||||
resolve these errors. This full re-index may take much longer
|
||||
than a normal update.
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Time</TableHead>
|
||||
<TableHead>Document ID</TableHead>
|
||||
<TableHead className="w-1/2">Error Message</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{errors.items.map((error) => (
|
||||
<TableRow key={error.id}>
|
||||
<TableCell>{localizeAndPrettify(error.time_created)}</TableCell>
|
||||
<TableCell>
|
||||
{error.document_link ? (
|
||||
<a
|
||||
href={error.document_link}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-link hover:underline"
|
||||
>
|
||||
{error.document_id || error.entity_id || "Unknown"}
|
||||
</a>
|
||||
) : (
|
||||
error.document_id || error.entity_id || "Unknown"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="whitespace-normal">
|
||||
{error.failure_message}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span
|
||||
className={`px-2 py-1 rounded text-xs ${
|
||||
error.is_resolved
|
||||
? "bg-green-100 text-green-800"
|
||||
: "bg-red-100 text-red-800"
|
||||
}`}
|
||||
>
|
||||
{error.is_resolved ? "Resolved" : "Unresolved"}
|
||||
</span>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<div className="mt-4">
|
||||
{totalPages > 1 && (
|
||||
<div className="flex-1 flex justify-center mb-2">
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage + 1}
|
||||
onPageChange={(page) => onPageChange(page - 1)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex w-full">
|
||||
<div className="flex gap-2 ml-auto">
|
||||
{hasUnresolvedErrors && !isResolvingErrors && (
|
||||
<Button
|
||||
onClick={onResolveAll}
|
||||
variant="default"
|
||||
className="ml-4 whitespace-nowrap"
|
||||
>
|
||||
Resolve All
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -34,26 +34,38 @@ import usePaginatedFetch from "@/hooks/usePaginatedFetch";
|
||||
const ITEMS_PER_PAGE = 8;
|
||||
const PAGES_PER_BATCH = 8;
|
||||
|
||||
export interface IndexingAttemptsTableProps {
|
||||
ccPair: CCPairFullInfo;
|
||||
indexAttempts: IndexAttemptSnapshot[];
|
||||
currentPage: number;
|
||||
totalPages: number;
|
||||
onPageChange: (page: number) => void;
|
||||
}
|
||||
|
||||
export function IndexingAttemptsTable({
|
||||
ccPair,
|
||||
indexAttempts,
|
||||
currentPage,
|
||||
totalPages,
|
||||
onPageChange,
|
||||
}: IndexingAttemptsTableProps) {
|
||||
export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
|
||||
const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState<
|
||||
number | null
|
||||
>(null);
|
||||
|
||||
if (!indexAttempts?.length) {
|
||||
const {
|
||||
currentPageData: pageOfIndexAttempts,
|
||||
isLoading,
|
||||
error,
|
||||
currentPage,
|
||||
totalPages,
|
||||
goToPage,
|
||||
} = usePaginatedFetch<IndexAttemptSnapshot>({
|
||||
itemsPerPage: ITEMS_PER_PAGE,
|
||||
pagesPerBatch: PAGES_PER_BATCH,
|
||||
endpoint: `${buildCCPairInfoUrl(ccPair.id)}/index-attempts`,
|
||||
});
|
||||
|
||||
if (isLoading || !pageOfIndexAttempts) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle={`Failed to fetch info on Connector with ID ${ccPair.id}`}
|
||||
errorMsg={error?.toString() || "Unknown error"}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!pageOfIndexAttempts?.length) {
|
||||
return (
|
||||
<Callout
|
||||
className="mt-4"
|
||||
@@ -66,7 +78,7 @@ export function IndexingAttemptsTable({
|
||||
);
|
||||
}
|
||||
|
||||
const indexAttemptToDisplayTraceFor = indexAttempts?.find(
|
||||
const indexAttemptToDisplayTraceFor = pageOfIndexAttempts?.find(
|
||||
(indexAttempt) => indexAttempt.id === indexAttemptTracePopupId
|
||||
);
|
||||
|
||||
@@ -107,7 +119,7 @@ export function IndexingAttemptsTable({
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{indexAttempts.map((indexAttempt) => {
|
||||
{pageOfIndexAttempts.map((indexAttempt) => {
|
||||
const docsPerMinute =
|
||||
getDocsProcessedPerMinute(indexAttempt)?.toFixed(2);
|
||||
return (
|
||||
@@ -149,6 +161,18 @@ export function IndexingAttemptsTable({
|
||||
<TableCell>{indexAttempt.total_docs_indexed}</TableCell>
|
||||
<TableCell>
|
||||
<div>
|
||||
{indexAttempt.error_count > 0 && (
|
||||
<Link
|
||||
className="cursor-pointer my-auto"
|
||||
href={`/admin/indexing/${indexAttempt.id}`}
|
||||
>
|
||||
<Text className="flex flex-wrap text-link whitespace-normal">
|
||||
<SearchIcon />
|
||||
View Errors
|
||||
</Text>
|
||||
</Link>
|
||||
)}
|
||||
|
||||
{indexAttempt.status === "success" && (
|
||||
<Text className="flex flex-wrap whitespace-normal">
|
||||
{"-"}
|
||||
@@ -185,7 +209,7 @@ export function IndexingAttemptsTable({
|
||||
<PageSelector
|
||||
totalPages={totalPages}
|
||||
currentPage={currentPage}
|
||||
onPageChange={onPageChange}
|
||||
onPageChange={goToPage}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { runConnector } from "@/lib/connector";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Text from "@/components/ui/text";
|
||||
import { triggerIndexing } from "./lib";
|
||||
import { mutate } from "swr";
|
||||
import { buildCCPairInfoUrl } from "./lib";
|
||||
import { useState } from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
@@ -21,6 +23,26 @@ function ReIndexPopup({
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
hide: () => void;
|
||||
}) {
|
||||
async function triggerIndexing(fromBeginning: boolean) {
|
||||
const errorMsg = await runConnector(
|
||||
connectorId,
|
||||
[credentialId],
|
||||
fromBeginning
|
||||
);
|
||||
if (errorMsg) {
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Triggered connector run",
|
||||
type: "success",
|
||||
});
|
||||
}
|
||||
mutate(buildCCPairInfoUrl(ccPairId));
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal title="Run Indexing" onOutsideClick={hide}>
|
||||
<div>
|
||||
@@ -28,13 +50,7 @@ function ReIndexPopup({
|
||||
variant="submit"
|
||||
className="ml-auto"
|
||||
onClick={() => {
|
||||
triggerIndexing(
|
||||
false,
|
||||
connectorId,
|
||||
credentialId,
|
||||
ccPairId,
|
||||
setPopup
|
||||
);
|
||||
triggerIndexing(false);
|
||||
hide();
|
||||
}}
|
||||
>
|
||||
@@ -52,13 +68,7 @@ function ReIndexPopup({
|
||||
variant="submit"
|
||||
className="ml-auto"
|
||||
onClick={() => {
|
||||
triggerIndexing(
|
||||
true,
|
||||
connectorId,
|
||||
credentialId,
|
||||
ccPairId,
|
||||
setPopup
|
||||
);
|
||||
triggerIndexing(true);
|
||||
hide();
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { runConnector } from "@/lib/connector";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { mutate } from "swr";
|
||||
|
||||
export function buildCCPairInfoUrl(ccPairId: string | number) {
|
||||
return `/api/manage/admin/cc-pair/${ccPairId}`;
|
||||
@@ -14,29 +11,3 @@ export function buildSimilarCredentialInfoURL(
|
||||
const base = `/api/manage/admin/similar-credentials/${source_type}`;
|
||||
return get_editable ? `${base}?get_editable=True` : base;
|
||||
}
|
||||
|
||||
export async function triggerIndexing(
|
||||
fromBeginning: boolean,
|
||||
connectorId: number,
|
||||
credentialId: number,
|
||||
ccPairId: number,
|
||||
setPopup: (popupSpec: PopupSpec | null) => void
|
||||
) {
|
||||
const errorMsg = await runConnector(
|
||||
connectorId,
|
||||
[credentialId],
|
||||
fromBeginning
|
||||
);
|
||||
if (errorMsg) {
|
||||
setPopup({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
});
|
||||
} else {
|
||||
setPopup({
|
||||
message: "Triggered connector run",
|
||||
type: "success",
|
||||
});
|
||||
}
|
||||
mutate(buildCCPairInfoUrl(ccPairId));
|
||||
}
|
||||
|
||||
@@ -25,24 +25,13 @@ import DeletionErrorStatus from "./DeletionErrorStatus";
|
||||
import { IndexingAttemptsTable } from "./IndexingAttemptsTable";
|
||||
import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster";
|
||||
import { ReIndexButton } from "./ReIndexButton";
|
||||
import { buildCCPairInfoUrl, triggerIndexing } from "./lib";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import {
|
||||
CCPairFullInfo,
|
||||
ConnectorCredentialPairStatus,
|
||||
IndexAttemptError,
|
||||
PaginatedIndexAttemptErrors,
|
||||
} from "./types";
|
||||
import { buildCCPairInfoUrl } from "./lib";
|
||||
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
|
||||
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import EditPropertyModal from "@/components/modals/EditPropertyModal";
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import IndexAttemptErrorsModal from "./IndexAttemptErrorsModal";
|
||||
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
|
||||
import { IndexAttemptSnapshot } from "@/lib/types";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
|
||||
// synchronize these validations with the SQLAlchemy connector class until we have a
|
||||
// centralized schema for both frontend and backend
|
||||
@@ -62,99 +51,43 @@ const PruneFrequencySchema = Yup.object().shape({
|
||||
.required("Property value is required"),
|
||||
});
|
||||
|
||||
const ITEMS_PER_PAGE = 8;
|
||||
const PAGES_PER_BATCH = 8;
|
||||
|
||||
function Main({ ccPairId }: { ccPairId: number }) {
|
||||
const router = useRouter();
|
||||
const router = useRouter(); // Initialize the router
|
||||
const {
|
||||
data: ccPair,
|
||||
isLoading: isLoadingCCPair,
|
||||
error: ccPairError,
|
||||
isLoading,
|
||||
error,
|
||||
} = useSWR<CCPairFullInfo>(
|
||||
buildCCPairInfoUrl(ccPairId),
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const {
|
||||
currentPageData: indexAttempts,
|
||||
isLoading: isLoadingIndexAttempts,
|
||||
currentPage,
|
||||
totalPages,
|
||||
goToPage,
|
||||
} = usePaginatedFetch<IndexAttemptSnapshot>({
|
||||
itemsPerPage: ITEMS_PER_PAGE,
|
||||
pagesPerBatch: PAGES_PER_BATCH,
|
||||
endpoint: `${buildCCPairInfoUrl(ccPairId)}/index-attempts`,
|
||||
});
|
||||
|
||||
const {
|
||||
currentPageData: indexAttemptErrorsPage,
|
||||
currentPage: errorsCurrentPage,
|
||||
totalPages: errorsTotalPages,
|
||||
goToPage: goToErrorsPage,
|
||||
} = usePaginatedFetch<IndexAttemptError>({
|
||||
itemsPerPage: 10,
|
||||
pagesPerBatch: 1,
|
||||
endpoint: `/api/manage/admin/cc-pair/${ccPairId}/errors`,
|
||||
});
|
||||
|
||||
const indexAttemptErrors = indexAttemptErrorsPage
|
||||
? {
|
||||
items: indexAttemptErrorsPage,
|
||||
total_items:
|
||||
errorsCurrentPage === errorsTotalPages &&
|
||||
indexAttemptErrorsPage.length === 0
|
||||
? 0
|
||||
: errorsTotalPages * 10,
|
||||
}
|
||||
: null;
|
||||
|
||||
const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
|
||||
const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
|
||||
const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
|
||||
const [showIndexAttemptErrors, setShowIndexAttemptErrors] = useState(false);
|
||||
const [showIsResolvingKickoffLoader, setShowIsResolvingKickoffLoader] =
|
||||
useState(false);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const latestIndexAttempt = indexAttempts?.[0];
|
||||
const isResolvingErrors =
|
||||
(latestIndexAttempt?.status === "in_progress" ||
|
||||
latestIndexAttempt?.status === "not_started") &&
|
||||
latestIndexAttempt?.from_beginning &&
|
||||
// if there are errors in the latest index attempt, we don't want to show the loader
|
||||
!indexAttemptErrors?.items?.some(
|
||||
(error) => error.index_attempt_id === latestIndexAttempt?.id
|
||||
);
|
||||
|
||||
const finishConnectorDeletion = useCallback(() => {
|
||||
router.push("/admin/indexing/status?message=connector-deleted");
|
||||
}, [router]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoadingCCPair) {
|
||||
if (isLoading) {
|
||||
return;
|
||||
}
|
||||
if (ccPair && !ccPairError) {
|
||||
if (ccPair && !error) {
|
||||
setHasLoadedOnce(true);
|
||||
}
|
||||
|
||||
if (
|
||||
(hasLoadedOnce && (ccPairError || !ccPair)) ||
|
||||
(hasLoadedOnce && (error || !ccPair)) ||
|
||||
(ccPair?.status === ConnectorCredentialPairStatus.DELETING &&
|
||||
!ccPair.connector)
|
||||
) {
|
||||
finishConnectorDeletion();
|
||||
}
|
||||
}, [
|
||||
isLoadingCCPair,
|
||||
ccPair,
|
||||
ccPairError,
|
||||
hasLoadedOnce,
|
||||
finishConnectorDeletion,
|
||||
]);
|
||||
}, [isLoading, ccPair, error, hasLoadedOnce, finishConnectorDeletion]);
|
||||
|
||||
const handleUpdateName = async (newName: string) => {
|
||||
try {
|
||||
@@ -258,19 +191,15 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
}
|
||||
};
|
||||
|
||||
if (isLoadingCCPair || isLoadingIndexAttempts) {
|
||||
if (isLoading) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (!ccPair || (!hasLoadedOnce && ccPairError)) {
|
||||
if (!ccPair || (!hasLoadedOnce && error)) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle={`Failed to fetch info on Connector with ID ${ccPairId}`}
|
||||
errorMsg={
|
||||
ccPairError?.info?.detail ||
|
||||
ccPairError?.toString() ||
|
||||
"Unknown error"
|
||||
}
|
||||
errorMsg={error?.info?.detail || error?.toString() || "Unknown error"}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -290,7 +219,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
return (
|
||||
<>
|
||||
{popup}
|
||||
{showIsResolvingKickoffLoader && !isResolvingErrors && <Spinner />}
|
||||
|
||||
{editingRefreshFrequency && (
|
||||
<EditPropertyModal
|
||||
@@ -316,32 +244,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{showIndexAttemptErrors && indexAttemptErrors && (
|
||||
<IndexAttemptErrorsModal
|
||||
errors={indexAttemptErrors}
|
||||
onClose={() => setShowIndexAttemptErrors(false)}
|
||||
onResolveAll={async () => {
|
||||
setShowIndexAttemptErrors(false);
|
||||
setShowIsResolvingKickoffLoader(true);
|
||||
await triggerIndexing(
|
||||
true,
|
||||
ccPair.connector.id,
|
||||
ccPair.credential.id,
|
||||
ccPair.id,
|
||||
setPopup
|
||||
);
|
||||
|
||||
// show the loader for a max of 10 seconds
|
||||
setTimeout(() => {
|
||||
setShowIsResolvingKickoffLoader(false);
|
||||
}, 10000);
|
||||
}}
|
||||
isResolvingErrors={isResolvingErrors}
|
||||
onPageChange={goToErrorsPage}
|
||||
currentPage={errorsCurrentPage}
|
||||
/>
|
||||
)}
|
||||
|
||||
<BackButton
|
||||
behaviorOverride={() => router.push("/admin/indexing/status")}
|
||||
/>
|
||||
@@ -440,46 +342,13 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* NOTE: no divider / title here for `ConfigDisplay` since it is optional and we need
|
||||
to render these conditionally.*/}
|
||||
<div className="mt-6">
|
||||
<div className="flex">
|
||||
<Title>Indexing Attempts</Title>
|
||||
</div>
|
||||
{indexAttemptErrors && indexAttemptErrors.total_items > 0 && (
|
||||
<Alert className="border-alert bg-yellow-50 my-2">
|
||||
<AlertCircle className="h-4 w-4 text-yellow-700" />
|
||||
<AlertTitle className="text-yellow-950 font-semibold">
|
||||
Some documents failed to index
|
||||
</AlertTitle>
|
||||
<AlertDescription className="text-yellow-900">
|
||||
{isResolvingErrors ? (
|
||||
<span>
|
||||
<span className="text-sm text-yellow-700 animate-pulse">
|
||||
Resolving failures
|
||||
</span>
|
||||
</span>
|
||||
) : (
|
||||
<>
|
||||
We ran into some issues while processing some documents.{" "}
|
||||
<b
|
||||
className="text-link cursor-pointer"
|
||||
onClick={() => setShowIndexAttemptErrors(true)}
|
||||
>
|
||||
View details.
|
||||
</b>
|
||||
</>
|
||||
)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
{indexAttempts && (
|
||||
<IndexingAttemptsTable
|
||||
ccPair={ccPair}
|
||||
indexAttempts={indexAttempts}
|
||||
currentPage={currentPage}
|
||||
totalPages={totalPages}
|
||||
onPageChange={goToPage}
|
||||
/>
|
||||
)}
|
||||
<IndexingAttemptsTable ccPair={ccPair} />
|
||||
</div>
|
||||
<Separator />
|
||||
<div className="flex mt-4">
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user