mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
1 Commits
nit_error
...
bugfix/ves
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1ac3ec7575 |
@@ -1,6 +1,6 @@
|
||||
name: Run Playwright Tests
|
||||
name: Run Chromatic Tests
|
||||
concurrency:
|
||||
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
@@ -198,47 +198,43 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
chromatic-tests:
|
||||
name: Chromatic Tests
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
needs: playwright-tests
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=32cpu-linux-x64,
|
||||
disk=large,
|
||||
"run-id=${{ github.run_id }}",
|
||||
]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
- name: Download Playwright test results
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: test-results
|
||||
path: ./web/test-results
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
- name: Run Chromatic
|
||||
uses: chromaui/action@latest
|
||||
with:
|
||||
playwright: true
|
||||
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
workingDir: ./web
|
||||
env:
|
||||
CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
35
.github/workflows/pr-integration-tests.yml
vendored
35
.github/workflows/pr-integration-tests.yml
vendored
@@ -99,7 +99,7 @@ jobs:
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
DEV_MODE=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
|
||||
id: start_docker_multi_tenant
|
||||
|
||||
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
|
||||
@@ -108,13 +108,12 @@ jobs:
|
||||
echo "Waiting for 3 minutes to ensure API server is ready..."
|
||||
sleep 180
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
@@ -144,27 +143,24 @@ jobs:
|
||||
- name: Stop multi-tenant Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
docker logs -f danswer-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
@@ -194,24 +190,15 @@ jobs:
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
docker run --rm --network danswer-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
@@ -221,8 +208,6 @@ jobs:
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
@@ -244,13 +229,13 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
@@ -264,4 +249,4 @@ jobs:
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
@@ -44,9 +44,6 @@ env:
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
2
.vscode/launch.template.jsonc
vendored
2
.vscode/launch.template.jsonc
vendored
@@ -205,7 +205,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
|
||||
@@ -28,16 +28,14 @@ RUN apt-get update && \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30 \
|
||||
libblkid1 \
|
||||
libmount1 \
|
||||
libsmartcols1 \
|
||||
libuuid1 \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc \
|
||||
nano \
|
||||
vim && \
|
||||
gcc && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Add checkpointing/failure handling
|
||||
|
||||
Revision ID: b7a7eee5aa15
|
||||
Revises: f39c5794c10a
|
||||
Create Date: 2025-01-24 15:17:36.763172
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7a7eee5aa15"
|
||||
down_revision = "f39c5794c10a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_index_attempt_cc_pair_settings_poll",
|
||||
"index_attempt",
|
||||
[
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
"status",
|
||||
sa.text("time_updated DESC"),
|
||||
],
|
||||
)
|
||||
|
||||
# Drop the old IndexAttemptError table
|
||||
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
|
||||
op.drop_table("index_attempt_errors")
|
||||
|
||||
# Create the new version of the table
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
sa.Column("document_link", sa.String(), nullable=True),
|
||||
sa.Column("entity_id", sa.String(), nullable=True),
|
||||
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("failure_message", sa.Text(), nullable=False),
|
||||
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("SET lock_timeout = '5s'")
|
||||
|
||||
# try a few times to drop the table, this has been observed to fail due to other locks
|
||||
# blocking the drop
|
||||
NUM_TRIES = 10
|
||||
for i in range(NUM_TRIES):
|
||||
try:
|
||||
op.drop_table("index_attempt_errors")
|
||||
break
|
||||
except Exception as e:
|
||||
if i == NUM_TRIES - 1:
|
||||
raise e
|
||||
print(f"Error dropping table: {e}. Retrying...")
|
||||
|
||||
op.execute("SET lock_timeout = DEFAULT")
|
||||
|
||||
# Recreate the old IndexAttemptError table
|
||||
op.create_table(
|
||||
"index_attempt_errors",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
|
||||
sa.Column("batch", sa.Integer(), nullable=True),
|
||||
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("traceback", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["index_attempt_id"],
|
||||
["index_attempt.id"],
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"index_attempt_id",
|
||||
"index_attempt_errors",
|
||||
["time_created"],
|
||||
)
|
||||
|
||||
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
|
||||
op.drop_column("index_attempt", "checkpoint_pointer")
|
||||
op.drop_column("index_attempt", "poll_range_start")
|
||||
op.drop_column("index_attempt", "poll_range_end")
|
||||
@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -17,7 +17,7 @@ logger = setup_logger()
|
||||
def _get_slack_document_ids_and_channels(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
|
||||
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="choose_tool",
|
||||
action=choose_tool,
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "call_tool"
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ if __name__ == "__main__":
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(unused=True)
|
||||
input = BasicInput(_unused=True)
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
config, _ = get_test_config(
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
class BasicInput(BaseModel):
|
||||
# Langgraph needs a nonempty input, but we pass in all static
|
||||
# data through a RunnableConfig.
|
||||
unused: bool = True
|
||||
_unused: bool = True
|
||||
|
||||
|
||||
## Graph Output State
|
||||
|
||||
@@ -9,6 +9,7 @@ class CoreState(BaseModel):
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
base_question: str = ""
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
@@ -17,4 +18,4 @@ class SubgraphCoreState(BaseModel):
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
log_messages: Annotated[list[str], add]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
@@ -12,45 +12,14 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
|
||||
SubQuestionAnswerCheckUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
|
||||
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
|
||||
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def check_sub_answer(
|
||||
state: AnswerQuestionState, config: RunnableConfig
|
||||
) -> SubQuestionAnswerCheckUpdate:
|
||||
@@ -84,42 +53,14 @@ def check_sub_answer(
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
|
||||
fast_llm.invoke,
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
)
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
answer_quality = binary_string_test(
|
||||
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Timeout Error - check sub answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
|
||||
answer_quality = True
|
||||
log_result = agent_error.error_result
|
||||
logger.error("LLM Rate Limit Error - check sub answer")
|
||||
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
answer_quality = "yes" in quality_str.lower()
|
||||
|
||||
return SubQuestionAnswerCheckUpdate(
|
||||
answer_quality=answer_quality,
|
||||
@@ -128,7 +69,7 @@ def check_sub_answer(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="check sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
result=f"Answer quality: {quality_str}",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
@@ -15,23 +16,6 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_sub_question_answer_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
dedup_sort_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
LLM_ANSWER_ERROR_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -46,25 +30,12 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_sub_answer(
|
||||
state: AnswerQuestionState,
|
||||
config: RunnableConfig,
|
||||
@@ -80,17 +51,12 @@ def generate_sub_answer(
|
||||
state.verified_reranked_documents
|
||||
level, question_num = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
|
||||
context_docs = dedup_sort_inference_section_list(context_docs)
|
||||
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
graph_config.inputs.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
cited_documents: list = []
|
||||
log_results = "No documents retrieved"
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
@@ -111,75 +77,43 @@ def generate_sub_answer(
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: list[str] = []
|
||||
|
||||
def stream_sub_answer() -> list[str]:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
|
||||
stream_sub_answer,
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=level,
|
||||
level_question_num=question_num,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate sub answer")
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate sub answer")
|
||||
response.append(content)
|
||||
|
||||
if agent_error:
|
||||
answer_str = LLM_ANSWER_ERROR_MESSAGE
|
||||
cited_documents = []
|
||||
log_results = (
|
||||
agent_error.error_result
|
||||
or "Sub-answer generation failed due to LLM error"
|
||||
)
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
logger.debug(
|
||||
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
)
|
||||
|
||||
else:
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
log_results = None
|
||||
answer_citation_ids = get_answer_citation_ids(answer_str)
|
||||
cited_documents = [
|
||||
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
|
||||
]
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
@@ -197,7 +131,7 @@ def generate_sub_answer(
|
||||
graph_component="initial - generate individual sub answer",
|
||||
node_name="generate sub answer",
|
||||
node_start_time=node_start_time,
|
||||
result=log_results or "",
|
||||
result="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -42,8 +42,10 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
|
||||
|
||||
|
||||
class SubQuestionAnsweringInput(SubgraphCoreState):
|
||||
question: str
|
||||
question_id: str
|
||||
question: str = ""
|
||||
question_id: str = (
|
||||
"" # 0_0 is original question, everything else is <level>_<question_num>.
|
||||
)
|
||||
# level 0 is original question and first decomposition, level 1 is follow up, etc
|
||||
# question_num is a unique number per original question per level.
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -25,31 +26,14 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_deduplicated_structured_subquestion_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
@@ -58,20 +42,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
@@ -80,17 +56,8 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. The initial answer could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
|
||||
general_error="General LLM Error. The initial answer could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
@@ -106,19 +73,15 @@ def generate_initial_answer(
|
||||
question = graph_config.inputs.search_request.query
|
||||
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
|
||||
|
||||
# get all documents cited in sub-questions
|
||||
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
|
||||
state.sub_question_results
|
||||
)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
orig_question_retrieval_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs = structured_subquestion_docs.cited_documents
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
orig_question_retrieval_documents
|
||||
):
|
||||
if original_doc_number not in structured_subquestion_docs.cited_documents:
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
@@ -127,18 +90,15 @@ def generate_initial_answer(
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
sub_questions: list[str] = []
|
||||
|
||||
# Create the list of documents to stream out. Start with the
|
||||
# ones that wil be in the context (or, if len == 0, use docs
|
||||
# that were retrieved for the original question)
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=relevant_docs,
|
||||
context_documents=structured_subquestion_docs.context_documents,
|
||||
original_question_docs=orig_question_retrieval_documents,
|
||||
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
|
||||
streamed_documents = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else state.orig_question_retrieved_documents[:15]
|
||||
)
|
||||
|
||||
# Use the query info from the base document retrieval
|
||||
@@ -148,13 +108,11 @@ def generate_initial_answer(
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
reranked_sections=streamed_documents,
|
||||
final_context_sections=streamed_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -170,7 +128,7 @@ def generate_initial_answer(
|
||||
writer,
|
||||
)
|
||||
|
||||
if len(answer_generation_documents.context_documents) == 0:
|
||||
if len(relevant_docs) == 0:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
@@ -234,13 +192,9 @@ def generate_initial_answer(
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
config=model.config,
|
||||
prompt_piece=doc_context,
|
||||
@@ -268,92 +222,32 @@ def generate_initial_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_initial_answer() -> list[str]:
|
||||
response: list[str] = []
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
start_stream_token = datetime.now()
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
stream_initial_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate initial answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate initial answer")
|
||||
|
||||
if agent_error:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
StreamingError(
|
||||
error=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
return InitialAnswerUpdate(
|
||||
initial_answer=None,
|
||||
answer_error=AgentErrorLog(
|
||||
error_message=agent_error.error_message or "An LLM error occurred",
|
||||
error_type=agent_error.error_type,
|
||||
error_result=agent_error.error_result,
|
||||
),
|
||||
initial_agent_stats=None,
|
||||
generated_sub_questions=sub_questions,
|
||||
agent_base_end_time=None,
|
||||
agent_base_metrics=None,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="initial - generate initial answer",
|
||||
node_name="generate initial answer",
|
||||
node_start_time=node_start_time,
|
||||
result=agent_error.error_result or "An LLM error occurred",
|
||||
)
|
||||
],
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
|
||||
@@ -10,10 +10,8 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def validate_initial_answer(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> InitialAnswerQualityUpdate:
|
||||
@@ -27,7 +25,7 @@ def validate_initial_answer(
|
||||
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
|
||||
)
|
||||
|
||||
verdict = True # not actually required as already streamed out. Refinement will do similar
|
||||
verdict = True
|
||||
|
||||
return InitialAnswerQualityUpdate(
|
||||
initial_answer_quality_eval=verdict,
|
||||
|
||||
@@ -23,8 +23,6 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
@@ -35,34 +33,17 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="LLM Timeout Error. Sub-questions could not be generated.",
|
||||
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
|
||||
general_error="General LLM Error. Sub-questions could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def decompose_orig_question(
|
||||
state: SubQuestionRetrievalState,
|
||||
config: RunnableConfig,
|
||||
@@ -104,15 +85,15 @@ def decompose_orig_question(
|
||||
]
|
||||
)
|
||||
|
||||
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
decomposition_prompt = (
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
|
||||
question=question, sample_doc_str=sample_doc_str, history=history
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
decomposition_prompt = (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
question=question, history=history
|
||||
)
|
||||
|
||||
# Start decomposition
|
||||
@@ -131,44 +112,32 @@ def decompose_orig_question(
|
||||
)
|
||||
|
||||
# dispatches custom events for subquestion tokens, adding in subquestion ids.
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
)
|
||||
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
)
|
||||
deomposition_response = merge_content(*streamed_tokens)
|
||||
|
||||
decomposition_response = merge_content(*streamed_tokens)
|
||||
# this call should only return strings. Commenting out for efficiency
|
||||
# assert [type(tok) == str for tok in streamed_tokens]
|
||||
|
||||
list_of_subqs = cast(str, decomposition_response).split("\n")
|
||||
# use no-op cast() instead of str() which runs code
|
||||
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
|
||||
list_of_subqs = cast(str, deomposition_response).split("\n")
|
||||
|
||||
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError) as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
logger.error("LLM Rate Limit Error - decompose orig question")
|
||||
raise e
|
||||
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
|
||||
|
||||
return InitialQuestionDecompositionUpdate(
|
||||
initial_sub_questions=initial_sub_questions,
|
||||
initial_sub_questions=decomp_list,
|
||||
agent_start_time=agent_start_time,
|
||||
agent_refined_start_time=None,
|
||||
agent_refined_end_time=None,
|
||||
@@ -182,7 +151,7 @@ def decompose_orig_question(
|
||||
graph_component="initial - generate sub answers",
|
||||
node_name="decompose original question",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
result=f"decomposed original question into {len(decomp_list)} subquestions",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
"""
|
||||
LangGraph edge to route to agent search.
|
||||
"""
|
||||
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "call_tool"
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need impo
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
|
||||
extract_entities_terms,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.generate_validate_refined_answer import (
|
||||
generate_validate_refined_answer,
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
|
||||
generate_refined_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
|
||||
ingest_refined_sub_answers,
|
||||
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
# Choose the initial tool
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=choose_tool,
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
# Call the tool, if required
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
# Use the tool response
|
||||
@@ -126,8 +126,8 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
|
||||
# Node to generate the refined answer
|
||||
graph.add_node(
|
||||
node="generate_validate_refined_answer",
|
||||
action=generate_validate_refined_answer,
|
||||
node="generate_refined_answer",
|
||||
action=generate_refined_answer,
|
||||
)
|
||||
|
||||
# Early node to extract the entities and terms from the initial answer,
|
||||
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["call_tool", "start_agent_search", "logging_node"],
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
@@ -215,11 +215,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
|
||||
graph.add_edge(
|
||||
start_key="ingest_refined_sub_answers",
|
||||
end_key="generate_validate_refined_answer",
|
||||
end_key="generate_refined_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_validate_refined_answer",
|
||||
start_key="generate_refined_answer",
|
||||
end_key="compare_answers",
|
||||
)
|
||||
graph.add_edge(
|
||||
@@ -252,7 +252,9 @@ if __name__ == "__main__":
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
|
||||
inputs = MainInput(log_messages=[])
|
||||
inputs = MainInput(
|
||||
base_question=graph_config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
@@ -11,53 +10,16 @@ from onyx.agents.agent_search.deep_search.main.states import (
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out, and the answers could not be compared.",
|
||||
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
|
||||
general_error="The LLM encountered an error, and the answers could not be compared.",
|
||||
)
|
||||
|
||||
_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE = (
|
||||
"Answer quality is not sufficient, so stay with the initial answer."
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def compare_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> InitialRefinedAnswerComparisonUpdate:
|
||||
@@ -72,78 +34,21 @@ def compare_answers(
|
||||
initial_answer = state.initial_answer
|
||||
refined_answer = state.refined_answer
|
||||
|
||||
# if answer quality is not sufficient, then stay with the initial answer
|
||||
if not state.refined_answer_quality:
|
||||
write_custom_event(
|
||||
"refined_answer_improvement",
|
||||
RefinedAnswerImprovement(
|
||||
refined_answer_improvement=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return InitialRefinedAnswerComparisonUpdate(
|
||||
refined_answer_improvement_eval=False,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="compare answers",
|
||||
node_start_time=node_start_time,
|
||||
result=_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
|
||||
question=question, initial_answer=initial_answer, refined_answer=refined_answer
|
||||
)
|
||||
|
||||
msg = [HumanMessage(content=compare_answers_prompt)]
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
# Get the rewritten queries in a defined format
|
||||
model = graph_config.tooling.fast_llm
|
||||
resp: BaseMessage | None = None
|
||||
refined_answer_improvement: bool | None = None
|
||||
|
||||
# no need to stream this
|
||||
try:
|
||||
resp = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
)
|
||||
resp = model.invoke(msg)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - compare answers")
|
||||
# continue as True in this support step
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - compare answers")
|
||||
# continue as True in this support step
|
||||
|
||||
if agent_error or resp is None:
|
||||
refined_answer_improvement = True
|
||||
if agent_error:
|
||||
log_result = agent_error.error_result
|
||||
else:
|
||||
log_result = "An answer could not be generated."
|
||||
|
||||
else:
|
||||
refined_answer_improvement = binary_string_test(
|
||||
text=cast(str, resp.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
log_result = f"Answer comparison: {refined_answer_improvement}"
|
||||
refined_answer_improvement = (
|
||||
isinstance(resp.content, str) and "yes" in resp.content.lower()
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
"refined_answer_improvement",
|
||||
@@ -160,7 +65,7 @@ def compare_answers(
|
||||
graph_component="main",
|
||||
node_name="compare answers",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
result=f"Answer comparison: {refined_answer_improvement}",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -21,18 +21,6 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
build_history_prompt,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
format_entity_term_extraction,
|
||||
@@ -42,35 +30,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS,
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ANSWERED_SUBQUESTIONS_DIVIDER = "\n\n---\n\n"
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The sub-questions could not be generated.",
|
||||
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
|
||||
general_error="The LLM encountered an error. The sub-questions could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def create_refined_sub_questions(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedQuestionDecompositionUpdate:
|
||||
@@ -107,10 +72,8 @@ def create_refined_sub_questions(
|
||||
|
||||
initial_question_answers = state.sub_question_results
|
||||
|
||||
addressed_subquestions_with_answers = [
|
||||
f"Subquestion: {x.question}\nSubanswer:\n{x.answer}"
|
||||
for x in initial_question_answers
|
||||
if x.verified_high_quality and x.answer
|
||||
addressed_question_list = [
|
||||
x.question for x in initial_question_answers if x.verified_high_quality
|
||||
]
|
||||
|
||||
failed_question_list = [
|
||||
@@ -119,14 +82,12 @@ def create_refined_sub_questions(
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS.format(
|
||||
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
|
||||
question=question,
|
||||
history=history,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_subquestions_with_answers=_ANSWERED_SUBQUESTIONS_DIVIDER.join(
|
||||
addressed_subquestions_with_answers
|
||||
),
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
@@ -135,67 +96,29 @@ def create_refined_sub_questions(
|
||||
# Grader
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - create refined sub questions")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - create refined sub questions")
|
||||
|
||||
if agent_error:
|
||||
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
|
||||
log_result = agent_error.error_result
|
||||
write_custom_event(
|
||||
"refined_sub_question_creation_error",
|
||||
StreamingError(
|
||||
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(msg),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
response = merge_content(*streamed_tokens)
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
response = merge_content(*streamed_tokens)
|
||||
raise ValueError("LLM response is not a string")
|
||||
|
||||
if isinstance(response, str):
|
||||
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
|
||||
else:
|
||||
raise ValueError("LLM response is not a string")
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_num, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = RefinementSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_num + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict = {}
|
||||
for sub_question_num, sub_question in enumerate(parsed_response):
|
||||
refined_sub_question = RefinementSubQuestion(
|
||||
sub_question=sub_question,
|
||||
sub_question_id=make_question_id(1, sub_question_num + 1),
|
||||
verified=False,
|
||||
answered=False,
|
||||
answer="",
|
||||
)
|
||||
|
||||
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
|
||||
|
||||
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
|
||||
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
|
||||
|
||||
return RefinedQuestionDecompositionUpdate(
|
||||
refined_sub_questions=refined_sub_question_dict,
|
||||
@@ -205,7 +128,7 @@ def create_refined_sub_questions(
|
||||
graph_component="main",
|
||||
node_name="create refined sub questions",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -11,10 +11,8 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def decide_refinement_need(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> RequireRefinemenEvalUpdate:
|
||||
@@ -28,19 +26,6 @@ def decide_refinement_need(
|
||||
|
||||
decision = True # TODO: just for current testing purposes
|
||||
|
||||
if state.answer_error:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="decide refinement need",
|
||||
node_start_time=node_start_time,
|
||||
result="Timeout Error",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
log_messages = [
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
|
||||
@@ -21,22 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def extract_entities_terms(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> EntityTermExtractionUpdate:
|
||||
@@ -90,42 +79,29 @@ def extract_entities_terms(
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - extract entities terms")
|
||||
except ValueError:
|
||||
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - extract entities terms")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -10,49 +11,27 @@ from onyx.agents.agent_search.deep_search.main.models import (
|
||||
AgentRefinedMetrics,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
|
||||
from onyx.agents.agent_search.deep_search.main.operations import logger
|
||||
from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
RefinedAnswerUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test_after_answer_separator,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
get_prompt_enrichment_components,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import AGENT_ANSWER_SEPARATOR
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
dedup_inference_sections,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_deduplicated_structured_subquestion_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
@@ -64,58 +43,26 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
REFINED_ANSWER_VALIDATION_PROMPT,
|
||||
)
|
||||
from onyx.prompts.agent_search import (
|
||||
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The refined answer could not be generated.",
|
||||
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
|
||||
general_error="The LLM encountered an error. The refined answer could not be generated.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def generate_validate_refined_answer(
|
||||
def generate_refined_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> RefinedAnswerUpdate:
|
||||
"""
|
||||
LangGraph node to generate the refined answer and validate it.
|
||||
LangGraph node to generate the refined answer.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
@@ -129,24 +76,19 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
|
||||
verified_reranked_documents = state.verified_reranked_documents
|
||||
|
||||
# get all documents cited in sub-questions
|
||||
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
|
||||
state.sub_question_results
|
||||
)
|
||||
|
||||
sub_questions_cited_documents = state.cited_documents
|
||||
original_question_verified_documents = (
|
||||
state.orig_question_verified_reranked_documents
|
||||
)
|
||||
original_question_retrieved_documents = state.orig_question_retrieved_documents
|
||||
|
||||
consolidated_context_docs = structured_subquestion_docs.cited_documents
|
||||
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
|
||||
|
||||
counter = 0
|
||||
for original_doc_number, original_doc in enumerate(
|
||||
original_question_verified_documents
|
||||
):
|
||||
if original_doc_number not in structured_subquestion_docs.cited_documents:
|
||||
if original_doc_number not in sub_questions_cited_documents:
|
||||
if (
|
||||
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
or len(consolidated_context_docs)
|
||||
@@ -157,16 +99,14 @@ def generate_validate_refined_answer(
|
||||
counter += 1
|
||||
|
||||
# sort docs by their scores - though the scores refer to different questions
|
||||
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
|
||||
relevant_docs = dedup_inference_sections(
|
||||
consolidated_context_docs, consolidated_context_docs
|
||||
)
|
||||
|
||||
# Create the list of documents to stream out. Start with the
|
||||
# ones that wil be in the context (or, if len == 0, use docs
|
||||
# that were retrieved for the original question)
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=relevant_docs,
|
||||
context_documents=structured_subquestion_docs.context_documents,
|
||||
original_question_docs=original_question_retrieved_documents,
|
||||
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER,
|
||||
streaming_docs = (
|
||||
relevant_docs
|
||||
if len(relevant_docs) > 0
|
||||
else original_question_retrieved_documents[:15]
|
||||
)
|
||||
|
||||
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
|
||||
@@ -174,13 +114,11 @@ def generate_validate_refined_answer(
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
# stream refined answer docs, or original question docs if no relevant docs are found
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
relevance_list = relevance_from_docs(relevant_docs)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
reranked_sections=streaming_docs,
|
||||
final_context_sections=streaming_docs,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -260,13 +198,8 @@ def generate_validate_refined_answer(
|
||||
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(relevant_docs)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
relevant_docs_str,
|
||||
@@ -296,89 +229,30 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
for message in model.stream(msg):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
return streamed_tokens
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
stream_refined_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - generate refined answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - generate refined answer")
|
||||
|
||||
if agent_error:
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
StreamingError(
|
||||
error=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
"refined_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=1,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=None,
|
||||
refined_answer_quality=False, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=None,
|
||||
agent_refined_end_time=None,
|
||||
agent_refined_metrics=AgentRefinedMetrics(
|
||||
refined_doc_boost_factor=0.0,
|
||||
refined_question_boost_factor=0.0,
|
||||
duration_s=None,
|
||||
),
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate refined answer",
|
||||
node_start_time=node_start_time,
|
||||
result=agent_error.error_result or "An LLM error occurred",
|
||||
)
|
||||
],
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
|
||||
streamed_tokens.append(content)
|
||||
|
||||
logger.debug(
|
||||
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
|
||||
@@ -387,47 +261,54 @@ def generate_validate_refined_answer(
|
||||
response = merge_content(*streamed_tokens)
|
||||
answer = cast(str, response)
|
||||
|
||||
# run a validation step for the refined answer only
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REFINED_ANSWER_VALIDATION_PROMPT.format(
|
||||
question=question,
|
||||
history=prompt_enrichment_components.history,
|
||||
answered_sub_questions=sub_question_answer_str,
|
||||
relevant_docs=relevant_docs_str,
|
||||
proposed_answer=answer,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
validation_model = graph_config.tooling.fast_llm
|
||||
try:
|
||||
validation_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
separator=AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Timeout Error - validate refined answer")
|
||||
|
||||
except LLMRateLimitError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Rate Limit Error - validate refined answer")
|
||||
|
||||
refined_agent_stats = RefinedAgentStats(
|
||||
revision_doc_efficiency=refined_doc_effectiveness,
|
||||
revision_question_efficiency=revision_question_efficiency,
|
||||
)
|
||||
|
||||
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
|
||||
logger.debug("-" * 10)
|
||||
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
|
||||
|
||||
logger.debug("-" * 100)
|
||||
|
||||
if state.initial_agent_stats:
|
||||
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
|
||||
"utilized_chunk_ratio", "--"
|
||||
)
|
||||
initial_support_boost_factor = (
|
||||
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
|
||||
)
|
||||
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
|
||||
"num_verified_documents", "--"
|
||||
)
|
||||
initial_verified_docs_avg_score = (
|
||||
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
|
||||
)
|
||||
initial_sub_questions_verified_docs = (
|
||||
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
|
||||
)
|
||||
|
||||
logger.debug("INITIAL AGENT STATS")
|
||||
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
|
||||
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
|
||||
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
|
||||
logger.debug(
|
||||
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
|
||||
)
|
||||
if refined_agent_stats:
|
||||
logger.debug("-" * 10)
|
||||
logger.debug("REFINED AGENT STATS")
|
||||
logger.debug(
|
||||
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
|
||||
)
|
||||
|
||||
agent_refined_end_time = datetime.now()
|
||||
if state.agent_refined_start_time:
|
||||
agent_refined_duration = (
|
||||
@@ -444,7 +325,7 @@ def generate_validate_refined_answer(
|
||||
|
||||
return RefinedAnswerUpdate(
|
||||
refined_answer=answer,
|
||||
refined_answer_quality=refined_answer_quality,
|
||||
refined_answer_quality=True, # TODO: replace this with the actual check value
|
||||
refined_agent_stats=refined_agent_stats,
|
||||
agent_refined_end_time=agent_refined_end_time,
|
||||
agent_refined_metrics=agent_refined_metrics,
|
||||
@@ -17,7 +17,6 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
@@ -77,7 +76,6 @@ class InitialAnswerUpdate(LoggerUpdate):
|
||||
"""
|
||||
|
||||
initial_answer: str | None = None
|
||||
answer_error: AgentErrorLog | None = None
|
||||
initial_agent_stats: InitialAgentResultStats | None = None
|
||||
generated_sub_questions: list[str] = []
|
||||
agent_base_end_time: datetime | None = None
|
||||
@@ -90,7 +88,6 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
|
||||
"""
|
||||
|
||||
refined_answer: str | None = None
|
||||
answer_error: AgentErrorLog | None = None
|
||||
refined_agent_stats: RefinedAgentStats | None = None
|
||||
refined_answer_quality: bool = False
|
||||
|
||||
|
||||
@@ -16,46 +16,16 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
QueryExpansionUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AgentLLMErrorType,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
|
||||
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
|
||||
general_error="Query rewriting failed due to LLM error - the original question will be used.",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def expand_queries(
|
||||
state: ExpandedRetrievalInput,
|
||||
config: RunnableConfig,
|
||||
@@ -71,7 +41,7 @@ def expand_queries(
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_num = 0, 0
|
||||
@@ -84,45 +54,13 @@ def expand_queries(
|
||||
)
|
||||
]
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
llm_response_list: list[BaseMessage_Content] = []
|
||||
llm_response = ""
|
||||
rewritten_queries = []
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
|
||||
)
|
||||
|
||||
try:
|
||||
llm_response_list = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
|
||||
0
|
||||
].content
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.timeout,
|
||||
)
|
||||
logger.error("LLM Timeout Error - expand queries")
|
||||
log_result = agent_error.error_result
|
||||
|
||||
except LLMRateLimitError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.RATE_LIMIT,
|
||||
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
|
||||
error_result=_llm_node_error_strings.rate_limit,
|
||||
)
|
||||
logger.error("LLM Rate Limit Error - expand queries")
|
||||
log_result = agent_error.error_result
|
||||
# use subquestion as query if query generation fails
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
|
||||
return QueryExpansionUpdate(
|
||||
expanded_queries=rewritten_queries,
|
||||
@@ -131,7 +69,7 @@ def expand_queries(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="expand queries",
|
||||
node_start_time=node_start_time,
|
||||
result=log_result,
|
||||
result=f"Number of expanded queries: {len(rewritten_queries)}",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -26,10 +26,8 @@ from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.context.search.postprocessing.postprocessing import should_rerank
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def rerank_documents(
|
||||
state: ExpandedRetrievalState, config: RunnableConfig
|
||||
) -> DocRerankingUpdate:
|
||||
@@ -55,7 +53,6 @@ def rerank_documents(
|
||||
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
|
||||
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -63,31 +60,23 @@ def rerank_documents(
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Initial default: no reranking. Will be overwritten below if reranking is warranted
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
if not allow_agent_reranking:
|
||||
logger.info("Use of local rerank model without GPU, skipping reranking")
|
||||
# No reranking, stay with verified_documents as default
|
||||
|
||||
else:
|
||||
# Reranking is warranted, use the rerank_sections functon
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
|
||||
@@ -28,10 +28,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def retrieve_documents(
|
||||
state: RetrievalInput, config: RunnableConfig
|
||||
) -> DocRetrievalUpdate:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@@ -12,40 +10,14 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
|
||||
DocVerificationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
binary_string_test,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_POSITIVE_VALUE_STR,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
DOCUMENT_VERIFICATION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
|
||||
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
|
||||
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def verify_documents(
|
||||
state: DocVerificationInput, config: RunnableConfig
|
||||
) -> DocVerificationUpdate:
|
||||
@@ -54,14 +26,12 @@ def verify_documents(
|
||||
|
||||
Args:
|
||||
state (DocVerificationInput): The current state
|
||||
config (RunnableConfig): Configuration containing AgentSearchConfig
|
||||
config (RunnableConfig): Configuration containing ProSearchConfig
|
||||
|
||||
Updates:
|
||||
verified_documents: list[InferenceSection]
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state.question
|
||||
retrieved_document_to_verify = state.retrieved_document_to_verify
|
||||
document_content = retrieved_document_to_verify.combined_content
|
||||
@@ -81,43 +51,12 @@ def verify_documents(
|
||||
)
|
||||
]
|
||||
|
||||
response: BaseMessage | None = None
|
||||
response = fast_llm.invoke(msg)
|
||||
|
||||
verified_documents = [
|
||||
retrieved_document_to_verify
|
||||
] # default is to treat document as relevant
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
if not binary_string_test(
|
||||
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
|
||||
):
|
||||
verified_documents = []
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Timeout Error - verify documents")
|
||||
|
||||
except LLMRateLimitError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Rate Limit Error - verify documents")
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(retrieved_document_to_verify)
|
||||
|
||||
return DocVerificationUpdate(
|
||||
verified_documents=verified_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="shared - expanded retrieval",
|
||||
node_name="verify documents",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
@@ -21,13 +21,9 @@ from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(SubgraphCoreState):
|
||||
# exception from 'no default value'for LangGraph input states
|
||||
# Here, sub_question_id default None implies usage for the
|
||||
# original question. This is sometimes needed for nested sub-graphs
|
||||
|
||||
question: str = ""
|
||||
base_search: bool = False
|
||||
sub_question_id: str | None = None
|
||||
question: str
|
||||
base_search: bool
|
||||
|
||||
|
||||
## Update/Return States
|
||||
@@ -38,7 +34,7 @@ class QueryExpansionUpdate(LoggerUpdate, BaseModel):
|
||||
log_messages: list[str] = []
|
||||
|
||||
|
||||
class DocVerificationUpdate(LoggerUpdate, BaseModel):
|
||||
class DocVerificationUpdate(BaseModel):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
|
||||
|
||||
|
||||
@@ -92,4 +88,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
|
||||
|
||||
|
||||
class RetrievalInput(ExpandedRetrievalInput):
|
||||
query_to_retrieve: str
|
||||
query_to_retrieve: str = ""
|
||||
|
||||
@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def choose_tool(
|
||||
def llm_tool_choice(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
||||
write_custom_event("basic_response", packet, writer)
|
||||
|
||||
|
||||
def call_tool(
|
||||
def tool_call(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
MainInput as MainInput_a,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
@@ -21,7 +21,6 @@ from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
@@ -34,7 +33,6 @@ from onyx.llm.factory import get_default_llms
|
||||
from onyx.tools.tool_runner import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_COMPILED_GRAPH: CompiledStateGraph | None = None
|
||||
@@ -74,15 +72,13 @@ def _parse_agent_event(
|
||||
return cast(AnswerPacket, event["data"])
|
||||
elif event["name"] == "refined_answer_improvement":
|
||||
return cast(RefinedAnswerImprovement, event["data"])
|
||||
elif event["name"] == "refined_sub_question_creation_error":
|
||||
return cast(StreamingError, event["data"])
|
||||
return None
|
||||
|
||||
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput,
|
||||
graph_input: BasicInput | MainInput_a,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -96,7 +92,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput,
|
||||
input: BasicInput | MainInput_a,
|
||||
) -> AnswerStream:
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
@@ -127,7 +123,9 @@ def run_main_graph(
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph()
|
||||
|
||||
input = MainInput(log_messages=[])
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
@@ -142,7 +140,7 @@ def run_basic_graph(
|
||||
) -> AnswerStream:
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(unused=True)
|
||||
input = BasicInput()
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
@@ -174,7 +172,9 @@ if __name__ == "__main__":
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.behavior.perform_initial_search_decomposition = True
|
||||
input = MainInput(log_messages=[])
|
||||
input = MainInput_a(
|
||||
base_question=config.inputs.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
tool_responses: list = []
|
||||
for output in run_graph(compiled_graph, config, input):
|
||||
|
||||
@@ -7,7 +7,6 @@ from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
AgentPromptEnrichmentComponents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
@@ -41,7 +40,13 @@ def build_sub_question_answer_prompt(
|
||||
|
||||
date_str = build_date_time_string()
|
||||
|
||||
docs_str = format_docs(docs)
|
||||
# TODO: This should include document metadata and title
|
||||
docs_format_list = [
|
||||
f"Document Number: [D{doc_num + 1}]\nContent: {doc.combined_content}\n\n"
|
||||
for doc_num, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
docs_str = "\n\n".join(docs_format_list)
|
||||
|
||||
docs_str = trim_prompt_piece(
|
||||
config,
|
||||
@@ -145,38 +150,3 @@ def get_prompt_enrichment_components(
|
||||
history=history,
|
||||
date_str=date_str,
|
||||
)
|
||||
|
||||
|
||||
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
|
||||
"""
|
||||
Tests if a string contains a positive value (case-insensitive).
|
||||
|
||||
Args:
|
||||
text: The string to test
|
||||
positive_value: The value to look for (defaults to "yes")
|
||||
|
||||
Returns:
|
||||
True if the positive value is found in the text
|
||||
"""
|
||||
return positive_value.lower() in text.lower()
|
||||
|
||||
|
||||
def binary_string_test_after_answer_separator(
|
||||
text: str, positive_value: str = "yes", separator: str = "Answer:"
|
||||
) -> bool:
|
||||
"""
|
||||
Tests if a string contains a positive value (case-insensitive).
|
||||
|
||||
Args:
|
||||
text: The string to test
|
||||
positive_value: The value to look for (defaults to "yes")
|
||||
|
||||
Returns:
|
||||
True if the positive value is found in the text
|
||||
"""
|
||||
|
||||
if separator not in text:
|
||||
return False
|
||||
relevant_text = text.split(f"{separator}")[-1]
|
||||
|
||||
return binary_string_test(relevant_text, positive_value)
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import numpy as np
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -100,106 +96,3 @@ def get_fit_scores(
|
||||
)
|
||||
|
||||
return fit_eval
|
||||
|
||||
|
||||
def get_answer_generation_documents(
|
||||
relevant_docs: list[InferenceSection],
|
||||
context_documents: list[InferenceSection],
|
||||
original_question_docs: list[InferenceSection],
|
||||
max_docs: int,
|
||||
) -> AnswerGenerationDocuments:
|
||||
"""
|
||||
Create a deduplicated list of documents to stream, prioritizing relevant docs.
|
||||
|
||||
Args:
|
||||
relevant_docs: Primary documents to include
|
||||
context_documents: Additional context documents to append
|
||||
original_question_docs: Original question documents to append
|
||||
max_docs: Maximum number of documents to return
|
||||
|
||||
Returns:
|
||||
List of deduplicated documents, limited to max_docs
|
||||
"""
|
||||
# get relevant_doc ids
|
||||
relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
|
||||
|
||||
# Start with relevant docs or fallback to original question docs
|
||||
streaming_documents = relevant_docs.copy()
|
||||
|
||||
# Use a set for O(1) lookups of document IDs
|
||||
seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
|
||||
|
||||
# Combine additional documents to check in one iteration
|
||||
additional_docs = context_documents + original_question_docs
|
||||
for doc_idx, doc in enumerate(additional_docs):
|
||||
doc_id = doc.center_chunk.document_id
|
||||
if doc_id not in seen_doc_ids:
|
||||
streaming_documents.append(doc)
|
||||
seen_doc_ids.add(doc_id)
|
||||
|
||||
streaming_documents = dedup_inference_section_list(streaming_documents)
|
||||
|
||||
relevant_streaming_docs = [
|
||||
doc
|
||||
for doc in streaming_documents
|
||||
if doc.center_chunk.document_id in relevant_doc_ids
|
||||
]
|
||||
relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
|
||||
|
||||
additional_streaming_docs = [
|
||||
doc
|
||||
for doc in streaming_documents
|
||||
if doc.center_chunk.document_id not in relevant_doc_ids
|
||||
]
|
||||
additional_streaming_docs = dedup_sort_inference_section_list(
|
||||
additional_streaming_docs
|
||||
)
|
||||
|
||||
for doc in additional_streaming_docs:
|
||||
if doc.center_chunk.score:
|
||||
doc.center_chunk.score += -2.0
|
||||
else:
|
||||
doc.center_chunk.score = -2.0
|
||||
|
||||
sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
|
||||
|
||||
return AnswerGenerationDocuments(
|
||||
streaming_documents=sorted_streaming_documents[:max_docs],
|
||||
context_documents=relevant_streaming_docs[:max_docs],
|
||||
)
|
||||
|
||||
|
||||
def dedup_sort_inference_section_list(
|
||||
sections: list[InferenceSection],
|
||||
) -> list[InferenceSection]:
|
||||
"""Deduplicates InferenceSections by document_id and sorts by score.
|
||||
|
||||
Args:
|
||||
sections: List of InferenceSections to deduplicate and sort
|
||||
|
||||
Returns:
|
||||
Deduplicated list of InferenceSections sorted by score in descending order
|
||||
"""
|
||||
# dedupe/merge with existing framework
|
||||
sections = dedup_inference_section_list(sections)
|
||||
|
||||
# Use dict to deduplicate by document_id, keeping highest scored version
|
||||
unique_sections: dict[str, InferenceSection] = {}
|
||||
for section in sections:
|
||||
doc_id = section.center_chunk.document_id
|
||||
if doc_id not in unique_sections:
|
||||
unique_sections[doc_id] = section
|
||||
continue
|
||||
|
||||
# Keep version with higher score
|
||||
existing_score = unique_sections[doc_id].center_chunk.score or 0
|
||||
new_score = section.center_chunk.score or 0
|
||||
if new_score > existing_score:
|
||||
unique_sections[doc_id] = section
|
||||
|
||||
# Sort by score in descending order, handling None scores
|
||||
sorted_sections = sorted(
|
||||
unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
|
||||
)
|
||||
|
||||
return sorted_sections
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
|
||||
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
|
||||
AGENT_LLM_RATELIMIT_MESSAGE = (
|
||||
"The agent encountered a rate limit error. Please try again."
|
||||
)
|
||||
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
|
||||
|
||||
AGENT_POSITIVE_VALUE_STR = "yes"
|
||||
AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
GENERAL_ERROR = "general_error"
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.deep_search.main.models import (
|
||||
@@ -58,12 +56,6 @@ class InitialAgentResultStats(BaseModel):
|
||||
agent_effectiveness: dict[str, float | int | None]
|
||||
|
||||
|
||||
class AgentErrorLog(BaseModel):
|
||||
error_message: str
|
||||
error_type: str
|
||||
error_result: str
|
||||
|
||||
|
||||
class RefinedAgentStats(BaseModel):
|
||||
revision_doc_efficiency: float | None
|
||||
revision_question_efficiency: float | None
|
||||
@@ -118,11 +110,6 @@ class SubQuestionAnswerResults(BaseModel):
|
||||
sub_question_retrieval_stats: AgentChunkRetrievalStats
|
||||
|
||||
|
||||
class StructuredSubquestionDocuments(BaseModel):
|
||||
cited_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
|
||||
|
||||
class CombinedAgentMetrics(BaseModel):
|
||||
timings: AgentTimings
|
||||
base_metrics: AgentBaseMetrics | None
|
||||
@@ -139,17 +126,3 @@ class AgentPromptEnrichmentComponents(BaseModel):
|
||||
persona_prompts: PersonaPromptExpressions
|
||||
history: str
|
||||
date_str: str
|
||||
|
||||
|
||||
class LLMNodeErrorStrings(BaseModel):
|
||||
timeout: str = "LLM Timeout Error"
|
||||
rate_limit: str = "LLM Rate Limit Error"
|
||||
general_error: str = "General LLM Error"
|
||||
|
||||
|
||||
class AnswerGenerationDocuments(BaseModel):
|
||||
streaming_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
|
||||
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
|
||||
@@ -12,13 +12,6 @@ def dedup_inference_sections(
|
||||
return deduped
|
||||
|
||||
|
||||
def dedup_inference_section_list(
|
||||
list: list[InferenceSection],
|
||||
) -> list[InferenceSection]:
|
||||
deduped = _merge_sections(list)
|
||||
return deduped
|
||||
|
||||
|
||||
def dedup_question_answer_results(
|
||||
question_answer_results_1: list[SubQuestionAnswerResults],
|
||||
question_answer_results_2: list[SubQuestionAnswerResults],
|
||||
|
||||
@@ -20,18 +20,10 @@ from onyx.agents.agent_search.models import GraphInputs
|
||||
from onyx.agents.agent_search.models import GraphPersistence
|
||||
from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
EntityRelationshipTermExtraction,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
StructuredSubquestionDocuments,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
@@ -42,10 +34,6 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -58,8 +46,6 @@ from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.persona import Persona
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
@@ -80,10 +66,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
|
||||
|
||||
# Post-processing
|
||||
@@ -396,26 +380,8 @@ def summarize_history(
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
history_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
llm.invoke,
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
)
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
)
|
||||
|
||||
history_response = llm.invoke(history_context_prompt)
|
||||
assert isinstance(history_response.content, str)
|
||||
|
||||
return history_response.content
|
||||
|
||||
|
||||
@@ -481,27 +447,3 @@ def remove_document_citations(text: str) -> str:
|
||||
# \d+ - one or more digits
|
||||
# \] - literal ] character
|
||||
return re.sub(r"\[(?:D|Q)?\d+\]", "", text)
|
||||
|
||||
|
||||
def get_deduplicated_structured_subquestion_documents(
|
||||
sub_question_results: list[SubQuestionAnswerResults],
|
||||
) -> StructuredSubquestionDocuments:
|
||||
"""
|
||||
Extract and deduplicate all cited documents from sub-question results.
|
||||
|
||||
Args:
|
||||
sub_question_results: List of sub-question results containing cited documents
|
||||
|
||||
Returns:
|
||||
Deduplicated list of cited documents
|
||||
"""
|
||||
cited_docs = [
|
||||
doc for result in sub_question_results for doc in result.cited_documents
|
||||
]
|
||||
context_docs = [
|
||||
doc for result in sub_question_results for doc in result.context_documents
|
||||
]
|
||||
return StructuredSubquestionDocuments(
|
||||
cited_documents=dedup_inference_section_list(cited_docs),
|
||||
context_documents=dedup_inference_section_list(context_docs),
|
||||
)
|
||||
|
||||
@@ -94,7 +94,6 @@ from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -108,6 +107,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
|
||||
@@ -36,15 +36,6 @@ beat_task_templates.extend(
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-checkpoint-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
@@ -16,7 +15,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -28,13 +26,7 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import (
|
||||
get_index_attempts_with_old_checkpoints,
|
||||
)
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
@@ -42,7 +34,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -79,123 +70,6 @@ from shared_configs.configs import SENTRY_DSN
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
"""The different statuses the watchdog can finish with.
|
||||
|
||||
TODO: create broader success/failure/abort categories
|
||||
"""
|
||||
|
||||
UNDEFINED = "undefined"
|
||||
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
SPAWN_FAILED = "spawn_failed" # connector spawn failed
|
||||
|
||||
BLOCKED_BY_DELETION = "blocked_by_deletion"
|
||||
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
|
||||
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
|
||||
FENCE_READINESS_TIMEOUT = (
|
||||
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
|
||||
)
|
||||
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
|
||||
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
|
||||
INDEX_ATTEMPT_MISMATCH = (
|
||||
"index_attempt_mismatch" # expected index attempt metadata not found in db
|
||||
)
|
||||
|
||||
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
|
||||
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
|
||||
|
||||
# the watchdog received a termination signal
|
||||
TERMINATED_BY_SIGNAL = "terminated_by_signal"
|
||||
|
||||
# the watchdog terminated the task due to no activity
|
||||
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
|
||||
|
||||
OUT_OF_MEMORY = "out_of_memory"
|
||||
|
||||
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
|
||||
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
|
||||
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
|
||||
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
|
||||
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
|
||||
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
|
||||
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
|
||||
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
|
||||
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
|
||||
}
|
||||
|
||||
return _ENUM_TO_CODE[self]
|
||||
|
||||
@classmethod
|
||||
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
|
||||
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
|
||||
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
|
||||
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
|
||||
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
|
||||
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
|
||||
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
|
||||
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
|
||||
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
|
||||
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
|
||||
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
|
||||
}
|
||||
|
||||
if code in _CODE_TO_ENUM:
|
||||
return _CODE_TO_ENUM[code]
|
||||
|
||||
return IndexingWatchdogTerminalStatus.UNDEFINED
|
||||
|
||||
|
||||
class SimpleJobResult:
|
||||
"""The data we want to have when the watchdog finishes"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
|
||||
self.connector_source = None
|
||||
self.exit_code = None
|
||||
self.exception_str = None
|
||||
|
||||
status: IndexingWatchdogTerminalStatus
|
||||
connector_source: str | None
|
||||
exit_code: int | None
|
||||
exception_str: str | None
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str | None
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
|
||||
|
||||
class ConnectorIndexingLogBuilder:
|
||||
def __init__(self, ctx: ConnectorIndexingContext):
|
||||
self.ctx = ctx
|
||||
|
||||
def build(self, msg: str, **kwargs: Any) -> str:
|
||||
msg_final = (
|
||||
f"{msg}: "
|
||||
f"tenant_id={self.ctx.tenant_id} "
|
||||
f"attempt={self.ctx.index_attempt_id} "
|
||||
f"cc_pair={self.ctx.cc_pair_id} "
|
||||
f"search_settings={self.ctx.search_settings_id}"
|
||||
)
|
||||
|
||||
# Append extra keyword arguments in logfmt style
|
||||
if kwargs:
|
||||
extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
|
||||
msg_final = f"{msg_final} {extra_logfmt}"
|
||||
|
||||
return msg_final
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -622,6 +496,7 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
@@ -638,21 +513,19 @@ def connector_indexing_task(
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
|
||||
f"fence={redis_connector.delete.fence_key}"
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise SimpleJobException(
|
||||
raise RuntimeError(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
@@ -661,24 +534,19 @@ def connector_indexing_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise SimpleJobException(
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise SimpleJobException(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload # The payload must exist
|
||||
if not payload:
|
||||
raise SimpleJobException(
|
||||
"connector_indexing_task: payload invalid or not found",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
raise ValueError("connector_indexing_task: payload invalid or not found")
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -688,11 +556,10 @@ def connector_indexing_task(
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise SimpleJobException(
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
f"payload_index_attempt={payload.index_attempt_id}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -716,14 +583,7 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise SimpleJobException(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
|
||||
)
|
||||
return None
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
@@ -732,10 +592,10 @@ def connector_indexing_task(
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}"
|
||||
)
|
||||
attempt_found = True
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
@@ -743,21 +603,16 @@ def connector_indexing_task(
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise SimpleJobException(
|
||||
f"cc_pair not found: cc_pair={cc_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
|
||||
|
||||
if not cc_pair.connector:
|
||||
raise SimpleJobException(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise SimpleJobException(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
raise ValueError(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
@@ -795,6 +650,20 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
@@ -809,49 +678,41 @@ def connector_indexing_task(
|
||||
return n_final_progress
|
||||
|
||||
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
result.connector_source = connector_source
|
||||
def connector_indexing_task_wrapper(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Just wraps connector_indexing_task so we can log any exceptions before
|
||||
re-raising it."""
|
||||
result: int | None = None
|
||||
|
||||
if job.process:
|
||||
result.exit_code = job.process.exitcode
|
||||
|
||||
if job.status != "error":
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
return result
|
||||
|
||||
ignore_exitcode = False
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
try:
|
||||
result = connector_indexing_task(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
else:
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
@@ -869,32 +730,12 @@ def connector_indexing_proxy_task(
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
|
||||
To work around this, we use pool=threads and proxy our work to a spawned task.
|
||||
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
result = SimpleJobResult()
|
||||
|
||||
ctx = ConnectorIndexingContext(
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
log_builder = ConnectorIndexingLogBuilder(ctx)
|
||||
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - starting",
|
||||
mp_start_method=str(multiprocessing.get_start_method()),
|
||||
)
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"mp_start_method={multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
@@ -903,7 +744,7 @@ def connector_indexing_proxy_task(
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
connector_indexing_task_wrapper,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
@@ -913,223 +754,139 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
if not job:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError("Index attempt not found")
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
result.connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
exit_code: int | None
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
exit_code = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task exceptioned"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
job.release()
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
log_builder.build("Indexing watchdog - termination signal detected")
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
|
||||
break
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception looking up index attempt"
|
||||
)
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
except Exception:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
normalized_exception_str = "None"
|
||||
if result.exception_str:
|
||||
normalized_exception_str = result.exception_str.replace(
|
||||
"\n", "\\n"
|
||||
).replace('"', '\\"')
|
||||
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=result.status.value,
|
||||
exit_code=str(result.exit_code),
|
||||
exception=f'"{normalized_exception_str}"',
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
|
||||
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled"
|
||||
)
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
return
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
|
||||
for attempt in old_attempts:
|
||||
task_logger.info(
|
||||
f"Cleaning up checkpoint for index attempt {attempt.id}"
|
||||
)
|
||||
cleanup_checkpoint_task.apply_async(
|
||||
kwargs={
|
||||
"index_attempt_id": attempt.id,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during checkpoint cleanup")
|
||||
return None
|
||||
finally:
|
||||
if locked:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_checkpoint_cleanup - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
|
||||
bind=True,
|
||||
)
|
||||
def cleanup_checkpoint_task(
|
||||
self: Task, *, index_attempt_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
cleanup_checkpoint(db_session, index_attempt_id)
|
||||
return
|
||||
|
||||
@@ -240,8 +240,7 @@ def validate_indexing_fence(
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
@@ -190,9 +190,9 @@ def _build_connector_start_latency_metric(
|
||||
desired_start_time = cc_pair.connector.time_created
|
||||
else:
|
||||
if not cc_pair.connector.refresh_freq:
|
||||
task_logger.debug(
|
||||
"Connector has no refresh_freq and this is a non-initial index attempt. "
|
||||
"Assuming user manually triggered indexing, so we'll skip start latency metric."
|
||||
task_logger.error(
|
||||
"Found non-initial index attempt for connector "
|
||||
"without refresh_freq. This should never happen."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -105,7 +105,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
|
||||
@@ -78,10 +78,6 @@ logger = setup_logger()
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
|
||||
task_logger.info("check_for_vespa_sync_task started")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
@@ -496,21 +492,13 @@ def monitor_document_set_taskset(
|
||||
task_logger.info(
|
||||
f"Successfully synced document set: document_set={document_set_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"update_sync_record_status exceptioned. "
|
||||
f"document_set_id={document_set_id} "
|
||||
"Resetting document set regardless."
|
||||
)
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=document_set_id,
|
||||
sync_type=SyncType.DOCUMENT_SET,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial_count,
|
||||
)
|
||||
|
||||
rds.reset()
|
||||
|
||||
|
||||
80
backend/onyx/background/indexing/checkpointing.py
Normal file
80
backend/onyx/background/indexing/checkpointing.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Experimental functionality related to splitting up indexing
|
||||
into a series of checkpoints to better handle intermittent failures
|
||||
/ jobs being killed by cloud providers."""
|
||||
import datetime
|
||||
|
||||
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
|
||||
|
||||
|
||||
def _2010_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _2020_dt() -> datetime.datetime:
|
||||
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
|
||||
|
||||
|
||||
def _default_end_time(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
) -> datetime.datetime:
|
||||
"""If year is before 2010, go to the beginning of 2010.
|
||||
If year is 2010-2020, go in 5 year increments.
|
||||
If year > 2020, then go in 180 day increments.
|
||||
|
||||
For connectors that don't support a `filter_by` and instead rely on `sort_by`
|
||||
for polling, then this will cause a massive duplication of fetches. For these
|
||||
connectors, you may want to override this function to return a more reasonable
|
||||
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
|
||||
last_successful_run = (
|
||||
datetime_to_utc(last_successful_run) if last_successful_run else None
|
||||
)
|
||||
if last_successful_run is None or last_successful_run < _2010_dt():
|
||||
return _2010_dt()
|
||||
|
||||
if last_successful_run < _2020_dt():
|
||||
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
|
||||
|
||||
return last_successful_run + datetime.timedelta(days=180)
|
||||
|
||||
|
||||
def find_end_time_for_indexing_attempt(
|
||||
last_successful_run: datetime.datetime | None,
|
||||
# source_type can be used to override the default for certain connectors, currently unused
|
||||
source_type: DocumentSource,
|
||||
) -> datetime.datetime | None:
|
||||
"""Is the current time unless the connector is run over a large period, in which case it is
|
||||
split up into large time segments that become smaller as it approaches the present
|
||||
"""
|
||||
# NOTE: source_type can be used to override the default for certain connectors
|
||||
end_of_window = _default_end_time(last_successful_run)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
if end_of_window < now:
|
||||
return end_of_window
|
||||
|
||||
# None signals that we should index up to current time
|
||||
return None
|
||||
|
||||
|
||||
def get_time_windows_for_index_attempt(
|
||||
last_successful_run: datetime.datetime, source_type: DocumentSource
|
||||
) -> list[tuple[datetime.datetime, datetime.datetime]]:
|
||||
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
|
||||
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
|
||||
|
||||
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
|
||||
start_of_window: datetime.datetime | None = last_successful_run
|
||||
while start_of_window:
|
||||
end_of_window = find_end_time_for_indexing_attempt(
|
||||
last_successful_run=start_of_window, source_type=source_type
|
||||
)
|
||||
time_windows.append(
|
||||
(
|
||||
start_of_window,
|
||||
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
)
|
||||
)
|
||||
start_of_window = end_of_window
|
||||
|
||||
return time_windows
|
||||
@@ -1,200 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from io import BytesIO
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.object_size_check import deep_getsizeof
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
|
||||
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
|
||||
|
||||
|
||||
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
|
||||
return f"checkpoint_{index_attempt_id}.json"
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
|
||||
) -> str:
|
||||
"""Save a checkpoint for a given index attempt to the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=checkpoint_pointer,
|
||||
content=BytesIO(checkpoint.model_dump_json().encode()),
|
||||
display_name=checkpoint_pointer,
|
||||
file_origin=FileOrigin.INDEXING_CHECKPOINT,
|
||||
file_type="application/json",
|
||||
)
|
||||
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
index_attempt.checkpoint_pointer = checkpoint_pointer
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
return checkpoint_pointer
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> ConnectorCheckpoint | None:
|
||||
"""Load a checkpoint for a given index attempt from the file store"""
|
||||
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
|
||||
file_store = get_default_file_store(db_session)
|
||||
try:
|
||||
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
|
||||
checkpoint_data = checkpoint_io.read().decode("utf-8")
|
||||
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
def get_latest_valid_checkpoint(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> ConnectorCheckpoint:
|
||||
"""Get the latest valid checkpoint for a given connector credential pair"""
|
||||
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
db_session=db_session,
|
||||
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
|
||||
)
|
||||
checkpoint_candidates = [
|
||||
candidate
|
||||
for candidate in checkpoint_candidates
|
||||
if (
|
||||
candidate.poll_range_start == window_start
|
||||
and candidate.poll_range_end == window_end
|
||||
and candidate.status == IndexingStatus.FAILED
|
||||
and candidate.checkpoint_pointer is not None
|
||||
# we want to make sure that the checkpoint is actually useful
|
||||
# if it's only gone through a few docs, it's probably not worth
|
||||
# using. This also avoids weird cases where a connector is basically
|
||||
# non-functional but still "makes progress" by slowly moving the
|
||||
# checkpoint forward run after run
|
||||
and candidate.total_docs_indexed
|
||||
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
|
||||
)
|
||||
]
|
||||
|
||||
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
|
||||
# for now, capped at 10
|
||||
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
|
||||
logger.warning(
|
||||
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
|
||||
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
|
||||
"from scratch."
|
||||
)
|
||||
return ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
|
||||
# assumes latest checkpoint is the furthest along. This only isn't true
|
||||
# if something else has gone wrong.
|
||||
latest_valid_checkpoint_candidate = (
|
||||
checkpoint_candidates[0] if checkpoint_candidates else None
|
||||
)
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
if latest_valid_checkpoint_candidate:
|
||||
try:
|
||||
previous_checkpoint = load_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to load checkpoint from previous failed attempt with ID "
|
||||
f"{latest_valid_checkpoint_candidate.id}."
|
||||
)
|
||||
previous_checkpoint = None
|
||||
|
||||
if previous_checkpoint is not None:
|
||||
logger.info(
|
||||
f"Using checkpoint from previous failed attempt with ID "
|
||||
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
|
||||
f"{previous_checkpoint}"
|
||||
)
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=latest_valid_checkpoint_candidate.id,
|
||||
checkpoint=previous_checkpoint,
|
||||
)
|
||||
checkpoint = previous_checkpoint
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def get_index_attempts_with_old_checkpoints(
|
||||
db_session: Session, days_to_keep: int = 7
|
||||
) -> list[IndexAttempt]:
|
||||
"""Get all index attempts with checkpoints older than the specified number of days.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
days_to_keep: Number of days to keep checkpoints for (default: 7)
|
||||
|
||||
Returns:
|
||||
Number of checkpoints deleted
|
||||
"""
|
||||
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
|
||||
|
||||
# Find all index attempts with checkpoints older than cutoff_date
|
||||
old_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
and_(
|
||||
IndexAttempt.checkpoint_pointer.isnot(None),
|
||||
IndexAttempt.time_created < cutoff_date,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return old_attempts
|
||||
|
||||
|
||||
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if not index_attempt.checkpoint_pointer:
|
||||
return None
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.delete_file(index_attempt.checkpoint_pointer)
|
||||
|
||||
index_attempt.checkpoint_pointer = None
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
|
||||
"""Check if the checkpoint content size exceeds the limit (200MB)"""
|
||||
content_size = deep_getsizeof(checkpoint.checkpoint_content)
|
||||
if content_size > 200_000_000: # 200MB in bytes
|
||||
raise ValueError(
|
||||
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
|
||||
)
|
||||
@@ -5,8 +5,6 @@ not follow the expected behavior, etc.
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing as mp
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing.context import SpawnProcess
|
||||
@@ -20,16 +18,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SimpleJobException(Exception):
|
||||
"""lets us raise an exception that will return a specific error code"""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
code: int | None = kwargs.pop("code", None)
|
||||
self.code = code
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
JobStatusType = (
|
||||
Literal["error"]
|
||||
| Literal["finished"]
|
||||
@@ -40,10 +28,7 @@ JobStatusType = (
|
||||
|
||||
|
||||
def _initializer(
|
||||
func: Callable,
|
||||
queue: mp.Queue,
|
||||
args: list | tuple,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Initialize the child process with a fresh SQLAlchemy Engine.
|
||||
|
||||
@@ -67,29 +52,13 @@ def _initializer(
|
||||
)
|
||||
|
||||
# Proceed with executing the target function
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except SimpleJobException as e:
|
||||
logger.exception("SimpleJob raised a SimpleJobException")
|
||||
error_msg = traceback.format_exc()
|
||||
queue.put(error_msg) # Send the exception to the parent process
|
||||
|
||||
sys.exit(e.code) # use the given exit code
|
||||
except Exception:
|
||||
logger.exception("SimpleJob raised an exception")
|
||||
error_msg = traceback.format_exc()
|
||||
queue.put(error_msg) # Send the exception to the parent process
|
||||
|
||||
sys.exit(255) # use 255 to indicate a generic exception
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def _run_in_process(
|
||||
func: Callable,
|
||||
queue: mp.Queue,
|
||||
args: list | tuple,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
_initializer(func, queue, args, kwargs)
|
||||
_initializer(func, args, kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -98,8 +67,6 @@ class SimpleJob:
|
||||
|
||||
id: int
|
||||
process: Optional["SpawnProcess"] = None
|
||||
queue: Optional[mp.Queue] = None
|
||||
_exception: Optional[str] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@@ -133,15 +100,9 @@ class SimpleJob:
|
||||
def exception(self) -> str:
|
||||
"""Needed to match the Dask API, but not implemented since we don't currently
|
||||
have a way to get back the exception information from the child process."""
|
||||
|
||||
"""Retrieve exception from the multiprocessing queue if available."""
|
||||
if self._exception is None and self.queue and not self.queue.empty():
|
||||
self._exception = self.queue.get() # Get exception from queue
|
||||
|
||||
if self._exception:
|
||||
return self._exception
|
||||
|
||||
return f"Job with ID '{self.id}' did not report an exception."
|
||||
return (
|
||||
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
|
||||
)
|
||||
|
||||
|
||||
class SimpleJobClient:
|
||||
@@ -176,11 +137,8 @@ class SimpleJobClient:
|
||||
# this approach allows us to always "spawn" a new process regardless of
|
||||
# get_start_method's current setting
|
||||
ctx = mp.get_context("spawn")
|
||||
queue = ctx.Queue()
|
||||
process = ctx.Process(
|
||||
target=_run_in_process, args=(func, queue, args), daemon=True
|
||||
)
|
||||
job = SimpleJob(id=job_id, process=process, queue=queue)
|
||||
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
self.jobs[job_id] = job
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
import tracemalloc
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class MemoryTracer:
|
||||
def __init__(self, interval: int = 0, num_print_entries: int = 5):
|
||||
self.interval = interval
|
||||
self.num_print_entries = num_print_entries
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
self.counter = 0
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the memory tracer if interval is greater than 0."""
|
||||
if self.interval > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={self.interval}")
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
self._take_snapshot()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the memory tracer if it's running."""
|
||||
if self.interval > 0:
|
||||
self.log_final_diff()
|
||||
tracemalloc.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
def _take_snapshot(self) -> None:
|
||||
"""Take a snapshot and update internal snapshot states."""
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__),
|
||||
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
|
||||
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def _log_diff(
|
||||
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
|
||||
) -> None:
|
||||
"""Log the memory difference between two snapshots."""
|
||||
stats = current.compare_to(previous, "traceback")
|
||||
for s in stats[: self.num_print_entries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def increment_and_maybe_trace(self) -> None:
|
||||
"""Increment counter and perform trace if interval is hit."""
|
||||
if self.interval <= 0:
|
||||
return
|
||||
|
||||
self.counter += 1
|
||||
if self.counter % self.interval == 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
|
||||
)
|
||||
self._take_snapshot()
|
||||
if self.snapshot and self.snapshot_prev:
|
||||
self._log_diff(self.snapshot, self.snapshot_prev)
|
||||
|
||||
def log_final_diff(self) -> None:
|
||||
"""Log the final memory diff between start and end of indexing."""
|
||||
if self.interval <= 0:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
|
||||
)
|
||||
self._take_snapshot()
|
||||
if self.snapshot and self.snapshot_first:
|
||||
self._log_diff(self.snapshot, self.snapshot_first)
|
||||
@@ -1,40 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
class IndexAttemptErrorPydantic(BaseModel):
|
||||
id: int
|
||||
connector_credential_pair_id: int
|
||||
|
||||
document_id: str | None
|
||||
document_link: str | None
|
||||
|
||||
entity_id: str | None
|
||||
failed_time_range_start: datetime | None
|
||||
failed_time_range_end: datetime | None
|
||||
|
||||
failure_message: str
|
||||
is_resolved: bool = False
|
||||
|
||||
time_created: datetime
|
||||
|
||||
index_attempt_id: int
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
|
||||
return cls(
|
||||
id=model.id,
|
||||
connector_credential_pair_id=model.connector_credential_pair_id,
|
||||
document_id=model.document_id,
|
||||
document_link=model.document_link,
|
||||
entity_id=model.entity_id,
|
||||
failed_time_range_start=model.failed_time_range_start,
|
||||
failed_time_range_end=model.failed_time_range_end,
|
||||
failure_message=model.failure_message,
|
||||
is_resolved=model.is_resolved,
|
||||
time_created=model.time_created,
|
||||
index_attempt_id=model.index_attempt_id,
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -8,11 +7,8 @@ from datetime import timezone
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
|
||||
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import save_checkpoint
|
||||
from onyx.background.indexing.memory_tracer import MemoryTracer
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
@@ -21,8 +17,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -30,18 +24,15 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
@@ -62,7 +53,6 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
def _get_connector_runner(
|
||||
db_session: Session,
|
||||
attempt: IndexAttempt,
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
@@ -110,9 +100,7 @@ def _get_connector_runner(
|
||||
raise e
|
||||
|
||||
return ConnectorRunner(
|
||||
connector=runnable_connector,
|
||||
batch_size=batch_size,
|
||||
time_range=(start_time, end_time),
|
||||
connector=runnable_connector, time_range=(start_time, end_time)
|
||||
)
|
||||
|
||||
|
||||
@@ -171,66 +159,6 @@ class RunIndexingContext(BaseModel):
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _check_connector_and_attempt_status(
|
||||
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
|
||||
) -> None:
|
||||
"""
|
||||
Checks the status of the connector credential pair and index attempt.
|
||||
Raises a RuntimeError if any conditions are not met.
|
||||
"""
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
|
||||
if (
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
|
||||
def _check_failure_threshold(
|
||||
total_failures: int,
|
||||
document_count: int,
|
||||
batch_num: int,
|
||||
last_failure: ConnectorFailure | None,
|
||||
) -> None:
|
||||
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
|
||||
|
||||
We consider the threshold hit if:
|
||||
1. We have more than 3 failures AND
|
||||
2. Failures account for more than 10% of processed documents
|
||||
"""
|
||||
failure_ratio = total_failures / (document_count or 1)
|
||||
|
||||
FAILURE_THRESHOLD = 3
|
||||
FAILURE_RATIO_THRESHOLD = 0.1
|
||||
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
|
||||
logger.error(
|
||||
f"Connector run failed with '{total_failures}' errors "
|
||||
f"after '{batch_num}' batches."
|
||||
)
|
||||
if last_failure and last_failure.exception:
|
||||
raise last_failure.exception from last_failure.exception
|
||||
|
||||
raise RuntimeError(
|
||||
f"Connector run encountered too many errors, aborting. "
|
||||
f"Last error: {last_failure}"
|
||||
)
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
@@ -241,8 +169,11 @@ def _run_indexing(
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
|
||||
TODO: do not change index attempt statuses here ... instead, set signals in redis
|
||||
and allow the monitor function to clean them up
|
||||
"""
|
||||
start_time = time.monotonic() # jsut used for logging
|
||||
start_time = time.time()
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
@@ -290,46 +221,6 @@ def _run_indexing(
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt_start.search_settings_id,
|
||||
db_session=db_session_temp,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# add start/end now that they have been set
|
||||
index_attempt_start.poll_range_start = window_start
|
||||
index_attempt_start.poll_range_end = window_end
|
||||
db_session_temp.add(index_attempt_start)
|
||||
db_session_temp.commit()
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
@@ -343,6 +234,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
attempt_id=index_attempt_id,
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
@@ -354,73 +246,63 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
tracer: OnyxTracer
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
|
||||
tracer = OnyxTracer()
|
||||
tracer.start()
|
||||
tracer.snap()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
total_failures = 0
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
run_end_dt = None
|
||||
tracer_counter: int
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
for ind, (window_start, window_end) in enumerate(
|
||||
get_time_windows_for_index_attempt(
|
||||
last_successful_run=datetime.fromtimestamp(
|
||||
last_successful_index_time, tz=timezone.utc
|
||||
),
|
||||
source_type=db_connector.source,
|
||||
)
|
||||
):
|
||||
cc_pair_loop: ConnectorCredentialPair | None = None
|
||||
index_attempt_loop: IndexAttempt | None = None
|
||||
tracer_counter = 0
|
||||
|
||||
try:
|
||||
window_start = max(
|
||||
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
|
||||
datetime(1970, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling.
|
||||
if index_attempt.from_beginning:
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_loop_start = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop_start:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
attempt=index_attempt_loop_start,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
unresolved_only=True,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
doc_id_to_unresolved_errors: dict[
|
||||
str, list[IndexAttemptError]
|
||||
] = defaultdict(list)
|
||||
for error in unresolved_errors:
|
||||
if error.document_id:
|
||||
doc_id_to_unresolved_errors[error.document_id].append(error)
|
||||
|
||||
entity_based_unresolved_errors = [
|
||||
error for error in unresolved_errors if error.entity_id
|
||||
]
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.snap()
|
||||
for doc_batch in connector_runner.run():
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
@@ -431,37 +313,41 @@ def _run_indexing(
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
if (
|
||||
(
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
)
|
||||
# if it's deleting, we don't care if this is a secondary index
|
||||
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
|
||||
):
|
||||
# let the `except` block handle this
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
|
||||
index_attempt_loop = get_index_attempt(
|
||||
db_session_temp, index_attempt_id
|
||||
)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(
|
||||
f"Index attempt {index_attempt_id} not found in DB."
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# save the new checkpoint (if one is provided)
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# below is all document processing logic, so if no batch we can just continue
|
||||
if document_batch is None:
|
||||
continue
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
@@ -491,51 +377,15 @@ def _run_indexing(
|
||||
chunk_count += index_pipeline_result.total_chunks
|
||||
document_count += index_pipeline_result.total_docs
|
||||
|
||||
# resolve errors for documents that were successfully indexed
|
||||
failed_document_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in index_pipeline_result.failures
|
||||
if failure.failed_document
|
||||
]
|
||||
successful_document_ids = [
|
||||
document.id
|
||||
for document in document_batch
|
||||
if document.id not in failed_document_ids
|
||||
]
|
||||
for document_id in successful_document_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if document_id in doc_id_to_unresolved_errors:
|
||||
logger.info(
|
||||
f"Resolving IndexAttemptError for document '{document_id}'"
|
||||
)
|
||||
for error in doc_id_to_unresolved_errors[document_id]:
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
# add brand new failures
|
||||
if index_pipeline_result.failures:
|
||||
total_failures += len(index_pipeline_result.failures)
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
for failure in index_pipeline_result.failures:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures,
|
||||
document_count,
|
||||
batch_num,
|
||||
index_pipeline_result.failures[-1],
|
||||
)
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
# of the transactions when computing `NOW()`, so if we have
|
||||
# a long running transaction, the `time_updated` field will
|
||||
# be inaccurate
|
||||
db_session.commit()
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
|
||||
# so we need either to commit() or to use a new session
|
||||
update_docs_indexed(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
@@ -547,77 +397,126 @@ def _run_indexing(
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
tracer_counter += 1
|
||||
if (
|
||||
INDEXING_TRACER_INTERVAL > 0
|
||||
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
|
||||
):
|
||||
logger.debug(
|
||||
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
|
||||
# `make sure the checkpoints aren't getting too large`at some regular interval
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
run_end_dt = window_end
|
||||
if ctx.is_primary:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
# save latest checkpoint
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
f"{time.monotonic() - start_time} seconds"
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or (
|
||||
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
|
||||
)
|
||||
or (
|
||||
index_attempt_loop is not None
|
||||
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
|
||||
)
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
break
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
logger.debug(
|
||||
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
|
||||
)
|
||||
tracer.snap()
|
||||
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
|
||||
tracer.stop()
|
||||
logger.debug("Memory tracer stopped.")
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
if (
|
||||
index_attempt_md.num_exceptions > 0
|
||||
and index_attempt_md.num_exceptions >= batch_num
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason="All batches exceptioned.",
|
||||
)
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
raise Exception(
|
||||
f"Connector failed - All batches exceptioned: batches={batch_num}"
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
memory_tracer.stop()
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# resolve entity-based errors
|
||||
for error in entity_based_unresolved_errors:
|
||||
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
if total_failures == 0:
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session_temp)
|
||||
|
||||
create_milestone_and_report(
|
||||
@@ -636,7 +535,7 @@ def _run_indexing(
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"failures={total_failures} "
|
||||
f"exceptions={index_attempt_md.num_exceptions} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
@@ -648,7 +547,7 @@ def _run_indexing(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=window_end,
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
|
||||
@@ -659,43 +558,46 @@ def run_indexing_entrypoint(
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""Don't swallow exceptions here ... propagate them up."""
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
connector_config = (
|
||||
attempt.connector_credential_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
|
||||
)
|
||||
|
||||
77
backend/onyx/background/indexing/tracer.py
Normal file
77
backend/onyx/background/indexing/tracer.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import tracemalloc
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DANSWER_TRACEMALLOC_FRAMES = 10
|
||||
|
||||
|
||||
class OnyxTracer:
|
||||
def __init__(self) -> None:
|
||||
self.snapshot_first: tracemalloc.Snapshot | None = None
|
||||
self.snapshot_prev: tracemalloc.Snapshot | None = None
|
||||
self.snapshot: tracemalloc.Snapshot | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
|
||||
|
||||
def stop(self) -> None:
|
||||
tracemalloc.stop()
|
||||
|
||||
def snap(self) -> None:
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
|
||||
snapshot = snapshot.filter_traces(
|
||||
(
|
||||
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap>"
|
||||
), # Exclude importlib
|
||||
tracemalloc.Filter(
|
||||
False, "<frozen importlib._bootstrap_external>"
|
||||
), # Exclude external importlib
|
||||
)
|
||||
)
|
||||
|
||||
if not self.snapshot_first:
|
||||
self.snapshot_first = snapshot
|
||||
|
||||
if self.snapshot:
|
||||
self.snapshot_prev = self.snapshot
|
||||
|
||||
self.snapshot = snapshot
|
||||
|
||||
def log_snapshot(self, numEntries: int) -> None:
|
||||
if not self.snapshot:
|
||||
return
|
||||
|
||||
stats = self.snapshot.statistics("traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer snap: {s}")
|
||||
for line in s.traceback:
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
@staticmethod
|
||||
def log_diff(
|
||||
snap_current: tracemalloc.Snapshot,
|
||||
snap_previous: tracemalloc.Snapshot,
|
||||
numEntries: int,
|
||||
) -> None:
|
||||
stats = snap_current.compare_to(snap_previous, "traceback")
|
||||
for s in stats[:numEntries]:
|
||||
logger.debug(f"Tracer diff: {s}")
|
||||
for line in s.traceback.format():
|
||||
logger.debug(f"* {line}")
|
||||
|
||||
def log_previous_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_prev:
|
||||
return
|
||||
|
||||
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
|
||||
|
||||
def log_first_diff(self, numEntries: int) -> None:
|
||||
if not self.snapshot or not self.snapshot_first:
|
||||
return
|
||||
|
||||
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
|
||||
@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -82,26 +80,6 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
|
||||
using_cloud_reranking = (
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
|
||||
# TODO: this is a hack to force the query to be used for the search tool
|
||||
# this should be removed once we fully unify graph inputs (i.e.
|
||||
# remove SearchQuery entirely)
|
||||
if (
|
||||
force_use_tool.force_use
|
||||
and search_tool
|
||||
and force_use_tool.args
|
||||
and force_use_tool.tool_name == search_tool.name
|
||||
and QUERY_FIELD in force_use_tool.args
|
||||
):
|
||||
search_request.query = force_use_tool.args[QUERY_FIELD]
|
||||
|
||||
self.graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=prompt_builder,
|
||||
@@ -116,6 +94,7 @@ class Answer:
|
||||
force_use_tool=force_use_tool,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
assert db_session, "db_session must be provided for agentic persistence"
|
||||
self.graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -125,7 +104,6 @@ class Answer:
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
|
||||
@@ -8,101 +8,14 @@ AGENT_DEFAULT_RERANKING_HITS = 10
|
||||
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
|
||||
|
||||
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = 25
|
||||
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = 35
|
||||
|
||||
|
||||
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5
|
||||
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
|
||||
ALLOW_REFINEMENT = True
|
||||
|
||||
AGENT_DEFAULT_RETRIEVAL_HITS = 15
|
||||
AGENT_DEFAULT_RERANKING_HITS = 10
|
||||
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
|
||||
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
|
||||
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5
|
||||
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
|
||||
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
|
||||
)
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
|
||||
) or True # default True
|
||||
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
# Reranking agent configs
|
||||
# Reranking stats - no influence on flow outside of stats collection
|
||||
AGENT_RERANKING_STATS = (
|
||||
not os.environ.get("AGENT_RERANKING_STATS") == "True"
|
||||
) or False # default False
|
||||
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
|
||||
) # 15
|
||||
|
||||
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
|
||||
os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS")
|
||||
or AGENT_DEFAULT_RERANKING_HITS
|
||||
) # 10
|
||||
|
||||
AGENT_NUM_DOCS_FOR_DECOMPOSITION = int(
|
||||
os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION")
|
||||
or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION
|
||||
) # 3
|
||||
|
||||
AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int(
|
||||
os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION")
|
||||
or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
|
||||
) # 5
|
||||
|
||||
AGENT_EXPLORATORY_SEARCH_RESULTS = int(
|
||||
os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS")
|
||||
or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS
|
||||
) # 5
|
||||
|
||||
AGENT_MIN_ORIG_QUESTION_DOCS = int(
|
||||
os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS")
|
||||
or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS
|
||||
) # 3
|
||||
|
||||
AGENT_MAX_ANSWER_CONTEXT_DOCS = int(
|
||||
os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS")
|
||||
or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
os.environ.get("AGENT_MAX_STATIC_HISTORY_WORD_LENGTH")
|
||||
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
|
||||
) # 2000
|
||||
|
||||
AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = int(
|
||||
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER")
|
||||
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
) # 25
|
||||
|
||||
AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = int(
|
||||
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER")
|
||||
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
) # 35
|
||||
#####
|
||||
# Agent Configs
|
||||
#####
|
||||
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
@@ -164,173 +77,4 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
|
||||
) # 2000
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -169,11 +169,6 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
|
||||
)
|
||||
|
||||
# defaults to False
|
||||
# generally should only be used for
|
||||
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
|
||||
|
||||
# defaults to False
|
||||
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
|
||||
|
||||
@@ -626,8 +621,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
|
||||
|
||||
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
|
||||
# Set to true to mock LLM responses for testing purposes
|
||||
|
||||
@@ -125,7 +125,6 @@ class DocumentSource(str, Enum):
|
||||
GMAIL = "gmail"
|
||||
REQUESTTRACKER = "requesttracker"
|
||||
GITHUB = "github"
|
||||
GITBOOK = "gitbook"
|
||||
GITLAB = "gitlab"
|
||||
GURU = "guru"
|
||||
BOOKSTACK = "bookstack"
|
||||
@@ -165,9 +164,6 @@ class DocumentSource(str, Enum):
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
# Special case just for integration tests
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
@@ -246,7 +242,6 @@ class FileOrigin(str, Enum):
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
@@ -278,7 +273,6 @@ class OnyxCeleryQueues:
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
LLM_MODEL_UPDATE = "llm_model_update"
|
||||
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
@@ -298,7 +292,6 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
|
||||
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_doc_permissions_sync_beat"
|
||||
)
|
||||
@@ -374,10 +367,6 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
|
||||
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
|
||||
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
|
||||
return [(" ".join(combined) if combined else str(field_info), default_link)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [(str(item), default_link) for item in field_info]
|
||||
return [(item, default_link) for item in field_info]
|
||||
|
||||
return [(str(field_info), default_link)]
|
||||
|
||||
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, str | list[str]]]:
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
primary_field_value = (
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -20,139 +15,48 @@ logger = setup_logger()
|
||||
TimeRange = tuple[datetime, datetime]
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: ConnectorCheckpoint | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput,
|
||||
) -> CheckpointOutput:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner:
|
||||
"""
|
||||
Handles:
|
||||
- Batching
|
||||
- Additional exception logging
|
||||
- Combining different connector types to a single interface
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
batch_size: int,
|
||||
time_range: TimeRange | None = None,
|
||||
fail_loudly: bool = False,
|
||||
):
|
||||
self.connector = connector
|
||||
self.time_range = time_range
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.doc_batch: list[Document] = []
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
def run(
|
||||
self, checkpoint: ConnectorCheckpoint
|
||||
) -> Generator[
|
||||
tuple[
|
||||
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
self.doc_batch_generator = self.connector.poll_source(
|
||||
time_range[0].timestamp(), time_range[1].timestamp()
|
||||
)
|
||||
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
if time_range and fail_loudly:
|
||||
raise ValueError(
|
||||
"time_range specified, but passed in connector is not a PollConnector"
|
||||
)
|
||||
|
||||
self.doc_batch_generator = self.connector.load_from_state()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
if isinstance(self.connector, CheckpointConnector):
|
||||
if self.time_range is None:
|
||||
raise ValueError("time_range is required for CheckpointConnector")
|
||||
start = time.monotonic()
|
||||
for batch in self.doc_batch_generator:
|
||||
# to know how long connector is taking
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to build a batch."
|
||||
)
|
||||
|
||||
yield batch
|
||||
|
||||
start = time.monotonic()
|
||||
checkpoint_connector_generator = self.connector.load_from_checkpoint(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
next_checkpoint: ConnectorCheckpoint | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
self.doc_batch.append(document)
|
||||
|
||||
if failure is not None:
|
||||
yield None, failure, None
|
||||
|
||||
if len(self.doc_batch) >= self.batch_size:
|
||||
yield self.doc_batch, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
# yield remaining documents
|
||||
if len(self.doc_batch) > 0:
|
||||
yield self.doc_batch, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
yield None, None, next_checkpoint
|
||||
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
|
||||
)
|
||||
|
||||
else:
|
||||
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
finished_checkpoint.has_more = False
|
||||
|
||||
if isinstance(self.connector, PollConnector):
|
||||
if self.time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
for document_batch in self.connector.poll_source(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
):
|
||||
yield document_batch, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
for document_batch in self.connector.load_from_state():
|
||||
yield document_batch, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
@@ -172,6 +76,6 @@ class ConnectorRunner:
|
||||
)
|
||||
logger.error(
|
||||
f"Error in connector. type: {exc_type};\n"
|
||||
f"local_vars below -> \n{local_vars_str[:1024]}"
|
||||
f"local_vars below -> \n{local_vars_str}"
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.gitlab.connector import GitlabConnector
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
@@ -30,14 +29,12 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
|
||||
from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.linear.connector import LinearConnector
|
||||
from onyx.connectors.loopio.connector import LoopioConnector
|
||||
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
|
||||
from onyx.connectors.mock_connector.connector import MockConnector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
from onyx.connectors.onyx_jira.connector import JiraConnector
|
||||
@@ -45,7 +42,7 @@ from onyx.connectors.productboard.connector import ProductboardConnector
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.connectors.slack.connector import SlackPollConnector
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.connectors.web.connector import WebConnector
|
||||
from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
@@ -68,13 +65,12 @@ def identify_connector_class(
|
||||
DocumentSource.WEB: WebConnector,
|
||||
DocumentSource.FILE: LocalFileConnector,
|
||||
DocumentSource.SLACK: {
|
||||
InputType.POLL: SlackConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackConnector,
|
||||
InputType.POLL: SlackPollConnector,
|
||||
InputType.SLIM_RETRIEVAL: SlackPollConnector,
|
||||
},
|
||||
DocumentSource.GITHUB: GithubConnector,
|
||||
DocumentSource.GMAIL: GmailConnector,
|
||||
DocumentSource.GITLAB: GitlabConnector,
|
||||
DocumentSource.GITBOOK: GitbookConnector,
|
||||
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
|
||||
DocumentSource.BOOKSTACK: BookstackConnector,
|
||||
DocumentSource.CONFLUENCE: ConfluenceConnector,
|
||||
@@ -111,8 +107,6 @@ def identify_connector_class(
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
# just for integration tests
|
||||
DocumentSource.MOCK_CONNECTOR: MockConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
@@ -129,23 +123,10 @@ def identify_connector_class(
|
||||
|
||||
if any(
|
||||
[
|
||||
(
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector)
|
||||
),
|
||||
(
|
||||
input_type == InputType.POLL
|
||||
# either poll or checkpoint works for this, in the future
|
||||
# all connectors should be checkpoint connectors
|
||||
and (
|
||||
not issubclass(connector, PollConnector)
|
||||
and not issubclass(connector, CheckpointConnector)
|
||||
)
|
||||
),
|
||||
(
|
||||
input_type == InputType.EVENT
|
||||
and not issubclass(connector, EventConnector)
|
||||
),
|
||||
input_type == InputType.LOAD_STATE
|
||||
and not issubclass(connector, LoadConnector),
|
||||
input_type == InputType.POLL and not issubclass(connector, PollConnector),
|
||||
input_type == InputType.EVENT and not issubclass(connector, EventConnector),
|
||||
]
|
||||
):
|
||||
raise ConnectorMissingException(
|
||||
|
||||
@@ -1,279 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
GITBOOK_API_BASE = "https://api.gitbook.com/v1/"
|
||||
|
||||
|
||||
class GitbookApiClient:
|
||||
def __init__(self, access_token: str) -> None:
|
||||
self.access_token = access_token
|
||||
|
||||
def get(self, endpoint: str, params: dict[str, Any] | None = None) -> Any:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = urljoin(GITBOOK_API_BASE, endpoint.lstrip("/"))
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def get_page_content(self, space_id: str, page_id: str) -> dict[str, Any]:
|
||||
return self.get(f"/spaces/{space_id}/content/page/{page_id}")
|
||||
|
||||
|
||||
def _extract_text_from_document(document: dict[str, Any]) -> str:
|
||||
"""Extract text content from GitBook document structure by parsing the document nodes
|
||||
into markdown format."""
|
||||
|
||||
def parse_leaf(leaf: dict[str, Any]) -> str:
|
||||
text = leaf.get("text", "")
|
||||
leaf.get("marks", [])
|
||||
return text
|
||||
|
||||
def parse_text_node(node: dict[str, Any]) -> str:
|
||||
text = ""
|
||||
for leaf in node.get("leaves", []):
|
||||
text += parse_leaf(leaf)
|
||||
return text
|
||||
|
||||
def parse_block_node(node: dict[str, Any]) -> str:
|
||||
block_type = node.get("type", "")
|
||||
result = ""
|
||||
|
||||
if block_type == "heading-1":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"# {text}\n\n"
|
||||
|
||||
elif block_type == "heading-2":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"## {text}\n\n"
|
||||
|
||||
elif block_type == "heading-3":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-4":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"#### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-5":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"##### {text}\n\n"
|
||||
|
||||
elif block_type == "heading-6":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"###### {text}\n\n"
|
||||
|
||||
elif block_type == "list-unordered":
|
||||
for list_item in node.get("nodes", []):
|
||||
paragraph = list_item.get("nodes", [])[0]
|
||||
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
|
||||
result += f"* {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "paragraph":
|
||||
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
|
||||
result = f"{text}\n\n"
|
||||
|
||||
elif block_type == "list-tasks":
|
||||
for task_item in node.get("nodes", []):
|
||||
checked = task_item.get("data", {}).get("checked", False)
|
||||
paragraph = task_item.get("nodes", [])[0]
|
||||
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
|
||||
checkbox = "[x]" if checked else "[ ]"
|
||||
result += f"- {checkbox} {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "code":
|
||||
for code_line in node.get("nodes", []):
|
||||
if code_line.get("type") == "code-line":
|
||||
text = "".join(
|
||||
parse_text_node(n) for n in code_line.get("nodes", [])
|
||||
)
|
||||
result += f"{text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "blockquote":
|
||||
for quote_node in node.get("nodes", []):
|
||||
if quote_node.get("type") == "paragraph":
|
||||
text = "".join(
|
||||
parse_text_node(n) for n in quote_node.get("nodes", [])
|
||||
)
|
||||
result += f"> {text}\n"
|
||||
result += "\n"
|
||||
|
||||
elif block_type == "table":
|
||||
records = node.get("data", {}).get("records", {})
|
||||
definition = node.get("data", {}).get("definition", {})
|
||||
view = node.get("data", {}).get("view", {})
|
||||
|
||||
columns = view.get("columns", [])
|
||||
|
||||
header_cells = []
|
||||
for col_id in columns:
|
||||
col_def = definition.get(col_id, {})
|
||||
header_cells.append(col_def.get("title", ""))
|
||||
|
||||
result = "| " + " | ".join(header_cells) + " |\n"
|
||||
result += "|" + "---|" * len(header_cells) + "\n"
|
||||
|
||||
sorted_records = sorted(
|
||||
records.items(), key=lambda x: x[1].get("orderIndex", "")
|
||||
)
|
||||
|
||||
for record_id, record_data in sorted_records:
|
||||
values = record_data.get("values", {})
|
||||
row_cells = []
|
||||
for col_id in columns:
|
||||
fragment_id = values.get(col_id, "")
|
||||
fragment_text = ""
|
||||
for fragment in node.get("fragments", []):
|
||||
if fragment.get("fragment") == fragment_id:
|
||||
for frag_node in fragment.get("nodes", []):
|
||||
if frag_node.get("type") == "paragraph":
|
||||
fragment_text = "".join(
|
||||
parse_text_node(n)
|
||||
for n in frag_node.get("nodes", [])
|
||||
)
|
||||
break
|
||||
row_cells.append(fragment_text)
|
||||
result += "| " + " | ".join(row_cells) + " |\n"
|
||||
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
if not document or "document" not in document:
|
||||
return ""
|
||||
|
||||
markdown = ""
|
||||
nodes = document["document"].get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
markdown += parse_block_node(node)
|
||||
|
||||
return markdown
|
||||
|
||||
|
||||
def _convert_page_to_document(
|
||||
client: GitbookApiClient, space_id: str, page: dict[str, Any]
|
||||
) -> Document:
|
||||
page_id = page["id"]
|
||||
page_content = client.get_page_content(space_id, page_id)
|
||||
|
||||
return Document(
|
||||
id=f"gitbook-{space_id}-{page_id}",
|
||||
sections=[
|
||||
Section(
|
||||
link=page.get("urls", {}).get("app", ""),
|
||||
text=_extract_text_from_document(page_content),
|
||||
)
|
||||
],
|
||||
source=DocumentSource.GITBOOK,
|
||||
semantic_identifier=page.get("title", ""),
|
||||
doc_updated_at=datetime.fromisoformat(page["updatedAt"]).replace(
|
||||
tzinfo=timezone.utc
|
||||
),
|
||||
metadata={
|
||||
"path": page.get("path", ""),
|
||||
"type": page.get("type", ""),
|
||||
"kind": page.get("kind", ""),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class GitbookConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
space_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.space_id = space_id
|
||||
self.batch_size = batch_size
|
||||
self.access_token: str | None = None
|
||||
self.client: GitbookApiClient | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
access_token = credentials.get("gitbook_api_key")
|
||||
if not access_token:
|
||||
raise ConnectorMissingCredentialError("GitBook access token")
|
||||
self.access_token = access_token
|
||||
self.client = GitbookApiClient(access_token)
|
||||
|
||||
def _fetch_all_pages(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if not self.client:
|
||||
raise ConnectorMissingCredentialError("GitBook")
|
||||
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content")
|
||||
pages = content.get("pages", [])
|
||||
|
||||
current_batch: list[Document] = []
|
||||
for page in pages:
|
||||
updated_at = datetime.fromisoformat(page["updatedAt"])
|
||||
|
||||
if start and updated_at < start:
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
return
|
||||
if end and updated_at > end:
|
||||
continue
|
||||
|
||||
current_batch.append(
|
||||
_convert_page_to_document(self.client, self.space_id, page)
|
||||
)
|
||||
|
||||
if len(current_batch) >= self.batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Error fetching GitBook content: {str(e)}")
|
||||
raise
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_all_pages()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
return self._fetch_all_pages(start_datetime, end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = GitbookConnector(
|
||||
space_id=os.environ["GITBOOK_SPACE_ID"],
|
||||
)
|
||||
connector.load_credentials({"gitbook_api_key": os.environ["GITBOOK_API_KEY"]})
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
@@ -302,7 +302,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -17,7 +14,6 @@ SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
|
||||
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
|
||||
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
@@ -109,33 +105,3 @@ class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def handle_event(self, event: Any) -> GenerateDocumentsOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CheckpointConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Yields back documents or failures. Final return is the new checkpoint.
|
||||
|
||||
Final return can be access via either:
|
||||
|
||||
```
|
||||
try:
|
||||
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value # Extracting the return value
|
||||
print(checkpoint)
|
||||
```
|
||||
|
||||
OR
|
||||
|
||||
```
|
||||
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
|
||||
```
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SingleConnectorYield(BaseModel):
|
||||
documents: list[Document]
|
||||
checkpoint: ConnectorCheckpoint
|
||||
failures: list[ConnectorFailure]
|
||||
unhandled_exception: str | None = None
|
||||
|
||||
|
||||
class MockConnector(CheckpointConnector):
|
||||
def __init__(
|
||||
self,
|
||||
mock_server_host: str,
|
||||
mock_server_port: int,
|
||||
) -> None:
|
||||
self.mock_server_host = mock_server_host
|
||||
self.mock_server_port = mock_server_port
|
||||
self.client = httpx.Client(timeout=30.0)
|
||||
|
||||
self.connector_yields: list[SingleConnectorYield] | None = None
|
||||
self.current_yield_index: int = 0
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
response = self.client.get(self._get_mock_server_url("get-documents"))
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
self.connector_yields = [
|
||||
SingleConnectorYield(**yield_data) for yield_data in data
|
||||
]
|
||||
return None
|
||||
|
||||
def _get_mock_server_url(self, endpoint: str) -> str:
|
||||
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
|
||||
|
||||
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
|
||||
response = self.client.post(
|
||||
self._get_mock_server_url("add-checkpoint"),
|
||||
json=checkpoint.model_dump(mode="json"),
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
if self.connector_yields is None:
|
||||
raise ValueError("No connector yields configured")
|
||||
|
||||
# Save the checkpoint to the mock server
|
||||
self._save_checkpoint(checkpoint)
|
||||
|
||||
yield_index = self.current_yield_index
|
||||
self.current_yield_index += 1
|
||||
current_yield = self.connector_yields[yield_index]
|
||||
|
||||
# If the current yield has an unhandled exception, raise it
|
||||
# This is used to simulate an unhandled failure in the connector.
|
||||
if current_yield.unhandled_exception:
|
||||
raise RuntimeError(current_yield.unhandled_exception)
|
||||
|
||||
# yield all documents
|
||||
for document in current_yield.documents:
|
||||
yield document
|
||||
|
||||
for failure in current_yield.failures:
|
||||
yield failure
|
||||
|
||||
return current_yield.checkpoint
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
@@ -188,48 +187,36 @@ class SlimDocument(BaseModel):
|
||||
perm_sync_data: Any | None = None
|
||||
|
||||
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
batch_num: int | None = None
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
|
||||
class ConnectorCheckpoint(BaseModel):
|
||||
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
|
||||
checkpoint_content: dict
|
||||
has_more: bool
|
||||
class DocumentErrorSummary(BaseModel):
|
||||
id: str
|
||||
semantic_id: str
|
||||
section_link: str | None
|
||||
|
||||
@classmethod
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
|
||||
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
|
||||
return cls(
|
||||
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
|
||||
return cls(
|
||||
id=str(data.get("id")),
|
||||
semantic_id=str(data.get("semantic_id")),
|
||||
section_link=str(data.get("section_link")),
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"id": self.id,
|
||||
"semantic_id": self.semantic_id,
|
||||
"section_link": self.section_link,
|
||||
}
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
document_id: str
|
||||
document_link: str | None = None
|
||||
|
||||
|
||||
class EntityFailure(BaseModel):
|
||||
entity_id: str
|
||||
missed_time_range: tuple[datetime, datetime] | None = None
|
||||
|
||||
|
||||
class ConnectorFailure(BaseModel):
|
||||
failed_document: DocumentFailure | None = None
|
||||
failed_entity: EntityFailure | None = None
|
||||
failure_message: str
|
||||
exception: Exception | None = None
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
@model_validator(mode="before")
|
||||
def check_failed_fields(cls, values: dict) -> dict:
|
||||
failed_document = values.get("failed_document")
|
||||
failed_entity = values.get("failed_entity")
|
||||
if (failed_document is None and failed_entity is None) or (
|
||||
failed_document is not None and failed_entity is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"Exactly one of 'failed_document' or 'failed_entity' must be specified."
|
||||
)
|
||||
return values
|
||||
class IndexAttemptMetadata(BaseModel):
|
||||
batch_num: int | None = None
|
||||
num_exceptions: int = 0
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
||||
@@ -1,16 +1,10 @@
|
||||
import contextvars
|
||||
import copy
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypedDict
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
@@ -18,18 +12,14 @@ from slack_sdk.errors import SlackApiError
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
@@ -43,8 +33,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SLACK_LIMIT = 900
|
||||
|
||||
|
||||
ChannelType = dict[str, Any]
|
||||
MessageType = dict[str, Any]
|
||||
@@ -52,13 +40,6 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
class SlackCheckpointContent(TypedDict):
|
||||
channel_ids: list[str]
|
||||
channel_completion_map: dict[str, str]
|
||||
current_channel: ChannelType | None
|
||||
seen_thread_ts: list[str]
|
||||
|
||||
|
||||
def _collect_paginated_channels(
|
||||
client: WebClient,
|
||||
exclude_archived: bool,
|
||||
@@ -159,10 +140,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
|
||||
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
|
||||
|
||||
|
||||
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
|
||||
return f"{channel_id}__{thread_ts}"
|
||||
|
||||
|
||||
def thread_to_doc(
|
||||
channel: ChannelType,
|
||||
thread: ThreadType,
|
||||
@@ -205,7 +182,7 @@ def thread_to_doc(
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
id=f"{channel_id}__{thread[0]['ts']}",
|
||||
sections=[
|
||||
Section(
|
||||
link=get_message_link(event=m, client=client, channel_id=channel_id),
|
||||
@@ -290,97 +267,64 @@ def filter_channels(
|
||||
]
|
||||
|
||||
|
||||
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get a channel by its ID.
|
||||
|
||||
Args:
|
||||
client: The Slack WebClient instance
|
||||
channel_id: The ID of the channel to fetch
|
||||
|
||||
Returns:
|
||||
The channel information
|
||||
|
||||
Raises:
|
||||
SlackApiError: If the channel cannot be fetched
|
||||
"""
|
||||
response = make_slack_api_call_w_retries(
|
||||
client.conversations_info,
|
||||
channel=channel_id,
|
||||
)
|
||||
return cast(ChannelType, response["channel"])
|
||||
|
||||
|
||||
def _get_messages(
|
||||
channel: ChannelType,
|
||||
def _get_all_docs(
|
||||
client: WebClient,
|
||||
channels: list[str] | None = None,
|
||||
channel_name_regex_enabled: bool = False,
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
) -> tuple[list[MessageType], bool]:
|
||||
"""Slack goes from newest to oldest."""
|
||||
|
||||
# have to be in the channel in order to read messages
|
||||
if not channel["is_member"]:
|
||||
make_slack_api_call_w_retries(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
response = make_slack_api_call_w_retries(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
limit=_SLACK_LIMIT,
|
||||
)
|
||||
response.validate()
|
||||
|
||||
messages = cast(list[MessageType], response.get("messages", []))
|
||||
|
||||
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
|
||||
"next_cursor", ""
|
||||
)
|
||||
has_more = bool(cursor)
|
||||
return messages, has_more
|
||||
|
||||
|
||||
def _message_to_doc(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> Document | None:
|
||||
filtered_thread: ThreadType | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
# skip threads we've already seen, since we've already processed all
|
||||
# messages in that thread
|
||||
if thread_ts in seen_thread_ts:
|
||||
return None
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
slack_cleaner = SlackTextCleaner(client=client)
|
||||
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
elif not msg_filter_func(message):
|
||||
filtered_thread = [message]
|
||||
# Cache to prevent refetching via API since users
|
||||
user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
|
||||
if filtered_thread:
|
||||
return thread_to_doc(
|
||||
channel=channel,
|
||||
thread=filtered_thread,
|
||||
slack_cleaner=slack_cleaner,
|
||||
client=client,
|
||||
user_cache=user_cache,
|
||||
all_channels = get_channels(client)
|
||||
filtered_channels = filter_channels(
|
||||
all_channels, channels, channel_name_regex_enabled
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
channel_docs = 0
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client, channel=channel, oldest=oldest, latest=latest
|
||||
)
|
||||
|
||||
return None
|
||||
seen_thread_ts: set[str] = set()
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
filtered_thread: ThreadType | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
# skip threads we've already seen, since we've already processed all
|
||||
# messages in that thread
|
||||
if thread_ts in seen_thread_ts:
|
||||
continue
|
||||
seen_thread_ts.add(thread_ts)
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
elif not msg_filter_func(message):
|
||||
filtered_thread = [message]
|
||||
|
||||
if filtered_thread:
|
||||
channel_docs += 1
|
||||
yield thread_to_doc(
|
||||
channel=channel,
|
||||
thread=filtered_thread,
|
||||
slack_cleaner=slack_cleaner,
|
||||
client=client,
|
||||
user_cache=user_cache,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
|
||||
)
|
||||
|
||||
|
||||
def _get_all_doc_ids(
|
||||
@@ -424,7 +368,7 @@ def _get_all_doc_ids(
|
||||
for message_ts in message_ts_set:
|
||||
channel_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
|
||||
id=f"{channel_id}__{message_ts}",
|
||||
perm_sync_data={"channel_id": channel_id},
|
||||
)
|
||||
)
|
||||
@@ -432,51 +376,7 @@ def _get_all_doc_ids(
|
||||
yield channel_metadata_list
|
||||
|
||||
|
||||
def _process_message(
|
||||
message: MessageType,
|
||||
client: WebClient,
|
||||
channel: ChannelType,
|
||||
slack_cleaner: SlackTextCleaner,
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
seen_thread_ts: set[str],
|
||||
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
|
||||
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
|
||||
thread_ts = message.get("thread_ts")
|
||||
try:
|
||||
# causes random failures for testing checkpointing / continue on failure
|
||||
# import random
|
||||
# if random.random() > 0.95:
|
||||
# raise RuntimeError("Random failure :P")
|
||||
|
||||
doc = _message_to_doc(
|
||||
message=message,
|
||||
client=client,
|
||||
channel=channel,
|
||||
slack_cleaner=slack_cleaner,
|
||||
user_cache=user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
msg_filter_func=msg_filter_func,
|
||||
)
|
||||
return (doc, thread_ts, None)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message {message['ts']}")
|
||||
return (
|
||||
None,
|
||||
thread_ts,
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=_build_doc_id(
|
||||
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
|
||||
),
|
||||
document_link=get_message_link(message, client, channel["id"]),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
class SlackPollConnector(PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
channels: list[str] | None = None,
|
||||
@@ -490,14 +390,9 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
self.batch_size = batch_size
|
||||
self.client: WebClient | None = None
|
||||
|
||||
# just used for efficiency
|
||||
self.text_cleaner: SlackTextCleaner | None = None
|
||||
self.user_cache: dict[str, BasicExpertInfo | None] = {}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
bot_token = credentials["slack_bot_token"]
|
||||
self.client = WebClient(token=bot_token)
|
||||
self.text_cleaner = SlackTextCleaner(client=self.client)
|
||||
return None
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
@@ -516,155 +411,30 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> CheckpointOutput:
|
||||
"""Rough outline:
|
||||
|
||||
Step 1: Get all channels, yield back Checkpoint.
|
||||
Step 2: Loop through each channel. For each channel:
|
||||
Step 2.1: Get messages within the time range.
|
||||
Step 2.2: Process messages in parallel, yield back docs.
|
||||
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
|
||||
Slack returns messages from newest to oldest, so we need to keep track of
|
||||
the latest message we've seen in each channel.
|
||||
Step 2.4: If there are no more messages in the channel, switch the current
|
||||
channel to the next channel.
|
||||
"""
|
||||
if self.client is None or self.text_cleaner is None:
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
checkpoint_content = cast(
|
||||
SlackCheckpointContent,
|
||||
(
|
||||
copy.deepcopy(checkpoint.checkpoint_content)
|
||||
or {
|
||||
"channel_ids": None,
|
||||
"channel_completion_map": {},
|
||||
"current_channel": None,
|
||||
"seen_thread_ts": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
documents: list[Document] = []
|
||||
for document in _get_all_docs(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
|
||||
# throw an error if we use 0.0 on an account without infinite data
|
||||
# retention
|
||||
oldest=str(start) if start else None,
|
||||
latest=str(end),
|
||||
):
|
||||
documents.append(document)
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
documents = []
|
||||
|
||||
# if this is the very first time we've called this, need to
|
||||
# get all relevant channels and save them into the checkpoint
|
||||
if checkpoint_content["channel_ids"] is None:
|
||||
raw_channels = get_channels(self.client)
|
||||
filtered_channels = filter_channels(
|
||||
raw_channels, self.channels, self.channel_regex_enabled
|
||||
)
|
||||
if len(filtered_channels) == 0:
|
||||
return checkpoint
|
||||
|
||||
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
|
||||
checkpoint_content["current_channel"] = filtered_channels[0]
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=True,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
final_channel_ids = checkpoint_content["channel_ids"]
|
||||
channel = checkpoint_content["current_channel"]
|
||||
if channel is None:
|
||||
raise ValueError("current_channel key not found in checkpoint")
|
||||
|
||||
channel_id = channel["id"]
|
||||
if channel_id not in final_channel_ids:
|
||||
raise ValueError(f"Channel {channel_id} not found in checkpoint")
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
|
||||
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
|
||||
try:
|
||||
logger.debug(
|
||||
f"Getting messages for channel {channel} within range {oldest} - {latest}"
|
||||
)
|
||||
message_batch, has_more_in_channel = _get_messages(
|
||||
channel, self.client, oldest, latest
|
||||
)
|
||||
new_latest = message_batch[-1]["ts"] if message_batch else latest
|
||||
|
||||
# Process messages in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
futures: list[Future] = []
|
||||
for message in message_batch:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
futures.append(
|
||||
executor.submit(
|
||||
current_context.run,
|
||||
_process_message,
|
||||
message=message,
|
||||
client=self.client,
|
||||
channel=channel,
|
||||
slack_cleaner=self.text_cleaner,
|
||||
user_cache=self.user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
)
|
||||
)
|
||||
|
||||
for future in as_completed(futures):
|
||||
doc, thread_ts, failures = future.result()
|
||||
if doc:
|
||||
# handle race conditions here since this is single
|
||||
# threaded. Multi-threaded _process_message reads from this
|
||||
# but since this is single threaded, we won't run into simul
|
||||
# writes. At worst, we can duplicate a thread, which will be
|
||||
# deduped later on.
|
||||
if thread_ts not in seen_thread_ts:
|
||||
yield doc
|
||||
|
||||
if thread_ts:
|
||||
seen_thread_ts.add(thread_ts)
|
||||
elif failures:
|
||||
for failure in failures:
|
||||
yield failure
|
||||
|
||||
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
|
||||
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
|
||||
if has_more_in_channel:
|
||||
checkpoint_content["current_channel"] = channel
|
||||
else:
|
||||
new_channel_id = next(
|
||||
(
|
||||
channel_id
|
||||
for channel_id in final_channel_ids
|
||||
if channel_id
|
||||
not in checkpoint_content["channel_completion_map"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if new_channel_id:
|
||||
new_channel = _get_channel_by_id(self.client, new_channel_id)
|
||||
checkpoint_content["current_channel"] = new_channel
|
||||
else:
|
||||
checkpoint_content["current_channel"] = None
|
||||
|
||||
checkpoint = ConnectorCheckpoint(
|
||||
checkpoint_content=checkpoint_content, # type: ignore
|
||||
has_more=checkpoint_content["current_channel"] is not None,
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing channel {channel['name']}")
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=channel["id"],
|
||||
missed_time_range=(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
return checkpoint
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -672,7 +442,7 @@ if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
slack_channel = os.environ.get("SLACK_CHANNEL")
|
||||
connector = SlackConnector(
|
||||
connector = SlackPollConnector(
|
||||
channels=[slack_channel] if slack_channel else None,
|
||||
)
|
||||
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
|
||||
@@ -680,17 +450,6 @@ if __name__ == "__main__":
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
|
||||
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
|
||||
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
|
||||
try:
|
||||
for document_or_failure in gen:
|
||||
if isinstance(document_or_failure, Document):
|
||||
print(document_or_failure)
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
print(document_or_failure)
|
||||
except StopIteration as e:
|
||||
checkpoint = e.value
|
||||
print("Next checkpoint:", checkpoint)
|
||||
|
||||
print("Next checkpoint:", checkpoint)
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -34,14 +34,9 @@ def get_message_link(
|
||||
) -> str:
|
||||
channel_id = channel_id or event["channel"]
|
||||
message_ts = event["ts"]
|
||||
message_ts_without_dot = message_ts.replace(".", "")
|
||||
thread_ts = event.get("thread_ts")
|
||||
base_url = get_base_url(client.token)
|
||||
|
||||
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
|
||||
f"?thread_ts={thread_ts}" if thread_ts else ""
|
||||
)
|
||||
return link
|
||||
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
|
||||
permalink = response["permalink"]
|
||||
return permalink
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.parse
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from zulip import Client
|
||||
|
||||
@@ -41,39 +36,8 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.realm_name = realm_name
|
||||
|
||||
# Clean and normalize the URL
|
||||
realm_url = realm_url.strip().lower()
|
||||
|
||||
# Remove any trailing slashes
|
||||
realm_url = realm_url.rstrip("/")
|
||||
|
||||
# Ensure the URL has a scheme
|
||||
if not realm_url.startswith(("http://", "https://")):
|
||||
realm_url = f"https://{realm_url}"
|
||||
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(realm_url)
|
||||
|
||||
# Extract the base domain without any paths or ports
|
||||
netloc = parsed.netloc.split(":")[0] # Remove port if present
|
||||
|
||||
if not netloc:
|
||||
raise ValueError(
|
||||
f"Invalid realm URL format: {realm_url}. "
|
||||
f"URL must include a valid domain name."
|
||||
)
|
||||
|
||||
# Always use HTTPS for security
|
||||
self.base_url = f"https://{netloc}"
|
||||
self.client: Client | None = None
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to parse Zulip realm URL: {realm_url}. "
|
||||
f"Please provide a URL in the format: domain.com or https://domain.com. "
|
||||
f"Error: {str(e)}"
|
||||
)
|
||||
self.realm_url = realm_url if realm_url.endswith("/") else realm_url + "/"
|
||||
self.client: Client | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
contents = credentials["zuliprc_content"]
|
||||
@@ -91,17 +55,12 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
return None
|
||||
|
||||
def _message_to_narrow_link(self, m: Message) -> str:
|
||||
try:
|
||||
stream_name = m.display_recipient # assume str
|
||||
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
|
||||
topic_operand = encode_zulip_narrow_operand(m.subject)
|
||||
stream_name = m.display_recipient # assume str
|
||||
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
|
||||
topic_operand = encode_zulip_narrow_operand(m.subject)
|
||||
|
||||
narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
|
||||
return narrow_link
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Zulip message link: {e}")
|
||||
# Fallback to a basic link that at least includes the base URL
|
||||
return f"{self.base_url}#narrow/id/{m.id}"
|
||||
narrow_link = f"{self.realm_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
|
||||
return narrow_link
|
||||
|
||||
def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]:
|
||||
if self.client is None:
|
||||
@@ -124,40 +83,6 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
def _message_to_doc(self, message: Message) -> Document:
|
||||
text = f"{message.sender_full_name}: {message.content}"
|
||||
|
||||
try:
|
||||
# Convert timestamps to UTC datetime objects
|
||||
post_time = datetime.fromtimestamp(message.timestamp, tz=timezone.utc)
|
||||
edit_time = (
|
||||
datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc)
|
||||
if message.last_edit_timestamp is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Use the most recent edit time if available, otherwise use post time
|
||||
doc_time = edit_time if edit_time is not None else post_time
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse timestamp for message {message.id}: {e}")
|
||||
post_time = None
|
||||
edit_time = None
|
||||
doc_time = None
|
||||
|
||||
metadata: Dict[str, Union[str, List[str]]] = {
|
||||
"stream_name": str(message.display_recipient),
|
||||
"topic": str(message.subject),
|
||||
"sender_name": str(message.sender_full_name),
|
||||
"sender_email": str(message.sender_email),
|
||||
"message_timestamp": str(message.timestamp),
|
||||
"message_id": str(message.id),
|
||||
"stream_id": str(message.stream_id),
|
||||
"has_reactions": str(len(message.reactions) > 0),
|
||||
"content_type": str(message.content_type or "text"),
|
||||
}
|
||||
|
||||
# Always include edit timestamp in metadata when available
|
||||
if edit_time is not None:
|
||||
metadata["edit_timestamp"] = str(message.last_edit_timestamp)
|
||||
|
||||
return Document(
|
||||
id=f"{message.stream_id}__{message.id}",
|
||||
sections=[
|
||||
@@ -167,9 +92,8 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
],
|
||||
source=DocumentSource.ZULIP,
|
||||
semantic_identifier=f"{message.display_recipient} > {message.subject}",
|
||||
metadata=metadata,
|
||||
doc_updated_at=doc_time, # Use most recent edit time or post time
|
||||
semantic_identifier=message.display_recipient or message.subject,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
def _get_docs(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -20,7 +19,7 @@ class Message(BaseModel):
|
||||
sender_realm_str: str
|
||||
subject: str
|
||||
topic_links: Optional[List[Any]] = None
|
||||
last_edit_timestamp: Optional[int] = None
|
||||
last_edit_timestamp: Optional[int]
|
||||
edit_history: Any = None
|
||||
reactions: List[Any]
|
||||
submessages: List[Any]
|
||||
@@ -40,5 +39,5 @@ class GetMessagesResponse(BaseModel):
|
||||
found_oldest: Optional[bool] = None
|
||||
found_newest: Optional[bool] = None
|
||||
history_limited: Optional[bool] = None
|
||||
anchor: Optional[Union[str, int]] = None
|
||||
anchor: Optional[str] = None
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
|
||||
@@ -350,17 +350,13 @@ def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
include_deleted=include_deleted,
|
||||
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
|
||||
)
|
||||
|
||||
if chat_session.deleted and not include_deleted:
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Cannot delete an already deleted chat session")
|
||||
|
||||
if hard_delete:
|
||||
@@ -384,15 +380,7 @@ def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
|
||||
).fetchall()
|
||||
|
||||
for user_id, session_id in old_sessions:
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id, session_id, db_session, include_deleted=True, hard_delete=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
|
||||
|
||||
|
||||
def get_chat_message(
|
||||
@@ -628,7 +616,7 @@ def create_new_chat_message(
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
refined_answer_improvement: bool | None = None,
|
||||
refined_answer_improvement: bool = True,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
@@ -905,18 +893,14 @@ def translate_db_sub_questions_to_server_objects(
|
||||
question=sub_question.sub_question,
|
||||
answer=sub_question.sub_answer,
|
||||
sub_queries=sub_queries,
|
||||
context_docs=get_retrieval_docs_from_search_docs(
|
||||
verified_docs, sort_by_score=False
|
||||
),
|
||||
context_docs=get_retrieval_docs_from_search_docs(verified_docs),
|
||||
)
|
||||
)
|
||||
return sub_questions
|
||||
|
||||
|
||||
def get_retrieval_docs_from_search_docs(
|
||||
search_docs: list[SearchDoc],
|
||||
remove_doc_content: bool = False,
|
||||
sort_by_score: bool = True,
|
||||
search_docs: list[SearchDoc], remove_doc_content: bool = False
|
||||
) -> RetrievalDocs:
|
||||
top_documents = [
|
||||
translate_db_search_doc_to_server_search_doc(
|
||||
@@ -924,8 +908,7 @@ def get_retrieval_docs_from_search_docs(
|
||||
)
|
||||
for db_doc in search_docs
|
||||
]
|
||||
if sort_by_score:
|
||||
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
|
||||
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
|
||||
return RetrievalDocs(top_documents=top_documents)
|
||||
|
||||
|
||||
@@ -1035,7 +1018,7 @@ def log_agent_sub_question_results(
|
||||
sub_question = sub_question_answer_result.question
|
||||
sub_answer = sub_question_answer_result.answer
|
||||
sub_document_results = _create_citation_format_list(
|
||||
sub_question_answer_result.context_documents
|
||||
sub_question_answer_result.verified_reranked_documents
|
||||
)
|
||||
|
||||
sub_question_object = AgentSubQuestion(
|
||||
|
||||
@@ -18,7 +18,6 @@ import boto3
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
@@ -40,7 +39,6 @@ from onyx.configs.app_configs import POSTGRES_PASSWORD
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
@@ -189,38 +187,20 @@ class SqlEngine:
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 20,
|
||||
"max_overflow": 5,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
|
||||
)
|
||||
|
||||
# Start with base kwargs that are valid for all pool types
|
||||
final_engine_kwargs: dict[str, Any] = {}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
# if null pool is specified, then we need to make sure that
|
||||
# we remove any passed in kwargs related to pool size that would
|
||||
# cause the initialization to fail
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
final_engine_kwargs["poolclass"] = pool.NullPool
|
||||
if "pool_size" in final_engine_kwargs:
|
||||
del final_engine_kwargs["pool_size"]
|
||||
if "max_overflow" in final_engine_kwargs:
|
||||
del final_engine_kwargs["max_overflow"]
|
||||
else:
|
||||
final_engine_kwargs["pool_size"] = 20
|
||||
final_engine_kwargs["max_overflow"] = 5
|
||||
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
|
||||
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
|
||||
|
||||
# any passed in kwargs override the defaults
|
||||
final_engine_kwargs.update(engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
engine = create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
@@ -319,21 +299,13 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
engine_kwargs = {
|
||||
"connect_args": connect_args,
|
||||
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
engine_kwargs["poolclass"] = pool.NullPool
|
||||
else:
|
||||
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
|
||||
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
**engine_kwargs,
|
||||
connect_args=connect_args,
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@@ -11,7 +11,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentErrorSummary
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
@@ -40,27 +41,6 @@ def get_last_attempt_for_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
def get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[IndexAttempt]:
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
IndexAttempt.status.notin_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> IndexAttempt | None:
|
||||
@@ -635,32 +615,23 @@ def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
|
||||
def create_index_attempt_error(
|
||||
index_attempt_id: int | None,
|
||||
connector_credential_pair_id: int,
|
||||
failure: ConnectorFailure,
|
||||
batch: int | None,
|
||||
docs: list[Document],
|
||||
exception_msg: str,
|
||||
exception_traceback: str,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
doc_summaries = []
|
||||
for doc in docs:
|
||||
doc_summary = DocumentErrorSummary.from_document(doc)
|
||||
doc_summaries.append(doc_summary.to_dict())
|
||||
|
||||
new_error = IndexAttemptError(
|
||||
index_attempt_id=index_attempt_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
document_id=(
|
||||
failure.failed_document.document_id if failure.failed_document else None
|
||||
),
|
||||
document_link=(
|
||||
failure.failed_document.document_link if failure.failed_document else None
|
||||
),
|
||||
entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None),
|
||||
failed_time_range_start=(
|
||||
failure.failed_entity.missed_time_range[0]
|
||||
if failure.failed_entity and failure.failed_entity.missed_time_range
|
||||
else None
|
||||
),
|
||||
failed_time_range_end=(
|
||||
failure.failed_entity.missed_time_range[1]
|
||||
if failure.failed_entity and failure.failed_entity.missed_time_range
|
||||
else None
|
||||
),
|
||||
failure_message=failure.failure_message,
|
||||
is_resolved=False,
|
||||
batch=batch,
|
||||
doc_summaries=doc_summaries,
|
||||
error_msg=exception_msg,
|
||||
traceback=exception_traceback,
|
||||
)
|
||||
db_session.add(new_error)
|
||||
db_session.commit()
|
||||
@@ -678,42 +649,3 @@ def get_index_attempt_errors(
|
||||
|
||||
errors = db_session.scalars(stmt)
|
||||
return list(errors.all())
|
||||
|
||||
|
||||
def count_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
unresolved_only: bool,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(IndexAttemptError)
|
||||
.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
|
||||
)
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
result = db_session.scalar(stmt)
|
||||
return 0 if result is None else result
|
||||
|
||||
|
||||
def get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
unresolved_only: bool,
|
||||
db_session: Session,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[IndexAttemptError]:
|
||||
stmt = select(IndexAttemptError).where(
|
||||
IndexAttemptError.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
if unresolved_only:
|
||||
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
|
||||
|
||||
# Order by most recent first
|
||||
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
@@ -827,19 +827,6 @@ class IndexAttempt(Base):
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# for polling connectors, the start and end time of the poll window
|
||||
# will be set when the index attempt starts
|
||||
poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
|
||||
# Points to the last checkpoint that was saved for this run. The pointer here
|
||||
# can be taken to the FileStore to grab the actual checkpoint value
|
||||
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -883,13 +870,6 @@ class IndexAttempt(Base):
|
||||
desc("time_updated"),
|
||||
unique=False,
|
||||
),
|
||||
Index(
|
||||
"ix_index_attempt_cc_pair_settings_poll",
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
"status",
|
||||
desc("time_updated"),
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -906,33 +886,25 @@ class IndexAttempt(Base):
|
||||
|
||||
|
||||
class IndexAttemptError(Base):
|
||||
"""
|
||||
Represents an error that was encountered during an IndexAttempt.
|
||||
"""
|
||||
|
||||
__tablename__ = "index_attempt_errors"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
index_attempt_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("index_attempt.id"),
|
||||
nullable=False,
|
||||
)
|
||||
connector_credential_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=False,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
document_link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
failure_message: Mapped[str] = mapped_column(Text)
|
||||
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# The index of the batch where the error occurred (if looping thru batches)
|
||||
# Just informational.
|
||||
batch: Mapped[int | None] = mapped_column(Integer, default=None)
|
||||
doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
|
||||
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
traceback: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -941,6 +913,21 @@ class IndexAttemptError(Base):
|
||||
# This is the reverse side of the relationship
|
||||
index_attempt = relationship("IndexAttempt", back_populates="error_rows")
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"index_attempt_id",
|
||||
"time_created",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IndexAttempt(id={self.id!r}, "
|
||||
f"index_attempt_id={self.index_attempt_id!r}, "
|
||||
f"error_msg={self.error_msg!r})>"
|
||||
f"time_created={self.time_created!r}, "
|
||||
)
|
||||
|
||||
|
||||
class SyncRecord(Base):
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import time
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
@@ -221,49 +217,3 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
deployment_name=search_settings.deployment_name,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
def embed_chunks_with_failure_handling(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
"""Tries to embed all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
"""
|
||||
|
||||
# First try to embed all chunks in one batch
|
||||
try:
|
||||
return embedder.embed_chunks(chunks=chunks), []
|
||||
except Exception:
|
||||
logger.exception("Failed to embed chunk batch. Trying individual docs.")
|
||||
# wait a couple seconds to let any rate limits or temporary issues resolve
|
||||
time.sleep(2)
|
||||
|
||||
# Try embedding each document's chunks individually
|
||||
chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_by_doc[chunk.source_document.id].append(chunk)
|
||||
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
|
||||
for doc_id, chunks_for_doc in chunks_by_doc.items():
|
||||
try:
|
||||
doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
|
||||
embedded_chunks.extend(doc_embedded_chunks)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return embedded_chunks, failures
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import Protocol
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import get_access_for_documents
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
@@ -27,6 +29,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
@@ -38,12 +41,10 @@ from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
@@ -66,8 +67,6 @@ class IndexingPipelineResult(BaseModel):
|
||||
# number of chunks that were inserted into Vespa
|
||||
total_chunks: int
|
||||
|
||||
failures: list[ConnectorFailure]
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
@@ -157,10 +156,14 @@ def index_doc_batch_with_handler(
|
||||
document_index: DocumentIndex,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
attempt_id: int | None,
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
new_docs=0, total_docs=len(document_batch), total_chunks=0
|
||||
)
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
@@ -173,25 +176,47 @@ def index_doc_batch_with_handler(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to index document batch: {document_batch}")
|
||||
index_pipeline_result = IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(document_batch),
|
||||
total_chunks=0,
|
||||
failures=[
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=document.id,
|
||||
document_link=(
|
||||
document.sections[0].link if document.sections else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
|
||||
logger.error(
|
||||
"NOTE: HTTP Status 507 Insufficient Storage indicates "
|
||||
"you need to allocate more memory or disk space to the "
|
||||
"Vespa/index container."
|
||||
)
|
||||
for document in document_batch
|
||||
],
|
||||
|
||||
if INDEXING_EXCEPTION_LIMIT == 0:
|
||||
raise
|
||||
|
||||
trace = traceback.format_exc()
|
||||
create_index_attempt_error(
|
||||
attempt_id,
|
||||
batch=index_attempt_metadata.batch_num,
|
||||
docs=document_batch,
|
||||
exception_msg=str(e),
|
||||
exception_traceback=trace,
|
||||
db_session=db_session,
|
||||
)
|
||||
logger.exception(
|
||||
f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'"
|
||||
)
|
||||
|
||||
index_attempt_metadata.num_exceptions += 1
|
||||
if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT:
|
||||
logger.warning(
|
||||
f"Maximum number of exceptions for this index attempt "
|
||||
f"({INDEXING_EXCEPTION_LIMIT}) has been reached. "
|
||||
f"The next exception will abort the indexing attempt."
|
||||
)
|
||||
elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT:
|
||||
logger.warning(
|
||||
f"Maximum number of exceptions for this index attempt "
|
||||
f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded."
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
return index_pipeline_result
|
||||
|
||||
@@ -351,12 +376,8 @@ def index_doc_batch(
|
||||
document_ids=[doc.id for doc in filtered_documents],
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
|
||||
)
|
||||
|
||||
doc_descriptors = [
|
||||
@@ -369,19 +390,10 @@ def index_doc_batch(
|
||||
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
# NOTE: no special handling for failures here, since the chunker is not
|
||||
# a common source of failure for the indexing pipeline
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
)
|
||||
if chunks
|
||||
else ([], [])
|
||||
)
|
||||
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
|
||||
|
||||
updatable_ids = [doc.id for doc in ctx.updatable_docs]
|
||||
|
||||
@@ -447,11 +459,7 @@ def index_doc_batch(
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
insertion_records = document_index.index(
|
||||
chunks=access_aware_chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
@@ -511,7 +519,6 @@ def index_doc_batch(
|
||||
new_docs=len([r for r in insertion_records if r.already_existed is False]),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(access_aware_chunks),
|
||||
failures=vector_db_write_failures + embedding_failures,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -524,6 +531,7 @@ def build_indexing_pipeline(
|
||||
db_session: Session,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
attempt_id: int | None = None,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
@@ -545,6 +553,7 @@ def build_indexing_pipeline(
|
||||
embedder=embedder,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
attempt_id=attempt_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@@ -57,13 +57,6 @@ class DocAwareChunk(BaseChunk):
|
||||
"""Used when logging the identity of a chunk"""
|
||||
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
|
||||
|
||||
def get_link(self) -> str | None:
|
||||
return (
|
||||
self.source_document.sections[0].link
|
||||
if self.source_document.sections
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
class IndexChunk(DocAwareChunk):
|
||||
embeddings: ChunkEmbedding
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _log_insufficient_storage_error(e: Exception) -> None:
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
|
||||
logger.error(
|
||||
"NOTE: HTTP Status 507 Insufficient Storage indicates "
|
||||
"you need to allocate more memory or disk space to the "
|
||||
"Vespa/index container."
|
||||
)
|
||||
|
||||
|
||||
def write_chunks_to_vector_db_with_backoff(
|
||||
document_index: DocumentIndex,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
|
||||
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
|
||||
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
|
||||
vector DB interface assumes that all chunks for a single document are present.
|
||||
"""
|
||||
|
||||
# first try to write the chunks to the vector db
|
||||
try:
|
||||
return (
|
||||
list(
|
||||
document_index.index(
|
||||
chunks=chunks,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
),
|
||||
[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to write chunk batch to vector db. Trying individual docs."
|
||||
)
|
||||
|
||||
# give some specific logging on this common failure case.
|
||||
_log_insufficient_storage_error(e)
|
||||
|
||||
# wait a couple seconds just to give the vector db a chance to recover
|
||||
time.sleep(2)
|
||||
|
||||
# try writing each doc one by one
|
||||
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_for_docs[chunk.source_document.id].append(chunk)
|
||||
|
||||
insertion_records: list[DocumentInsertionRecord] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
for doc_id, chunks_for_doc in chunks_for_docs.items():
|
||||
try:
|
||||
insertion_records.extend(
|
||||
document_index.index(
|
||||
chunks=chunks_for_doc,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to write document chunks for '{doc_id}' to vector db"
|
||||
)
|
||||
|
||||
# give some specific logging on this common failure case.
|
||||
_log_insufficient_storage_error(e)
|
||||
|
||||
failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
|
||||
return insertion_records, failures
|
||||
@@ -52,18 +52,6 @@ litellm.telemetry = False
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
"""
|
||||
Exception raised when an LLM call times out.
|
||||
"""
|
||||
|
||||
|
||||
class LLMRateLimitError(Exception):
|
||||
"""
|
||||
Exception raised when an LLM call is rate limited.
|
||||
"""
|
||||
|
||||
|
||||
def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
|
||||
return "user"
|
||||
@@ -401,7 +389,6 @@ class DefaultMultiLLM(LLM):
|
||||
tool_choice: ToolChoiceOptions | None,
|
||||
stream: bool,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
@@ -409,6 +396,10 @@ class DefaultMultiLLM(LLM):
|
||||
self._record_call(processed_prompt)
|
||||
|
||||
try:
|
||||
print(
|
||||
"model is",
|
||||
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
|
||||
)
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
# model choice
|
||||
@@ -428,7 +419,7 @@ class DefaultMultiLLM(LLM):
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=0,
|
||||
timeout=timeout_override or self._timeout,
|
||||
timeout=self._timeout,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
@@ -447,12 +438,6 @@ class DefaultMultiLLM(LLM):
|
||||
except Exception as e:
|
||||
self._record_error(processed_prompt, e)
|
||||
# for break pointing
|
||||
if isinstance(e, litellm.Timeout):
|
||||
raise LLMTimeoutError(e)
|
||||
|
||||
elif isinstance(e, litellm.RateLimitError):
|
||||
raise LLMRateLimitError(e)
|
||||
|
||||
raise e
|
||||
|
||||
@property
|
||||
@@ -473,7 +458,6 @@ class DefaultMultiLLM(LLM):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> BaseMessage:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
@@ -481,12 +465,7 @@ class DefaultMultiLLM(LLM):
|
||||
response = cast(
|
||||
litellm.ModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
prompt, tools, tool_choice, False, structured_response_format
|
||||
),
|
||||
)
|
||||
choice = response.choices[0]
|
||||
@@ -504,31 +483,19 @@ class DefaultMultiLLM(LLM):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
yield self.invoke(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
)
|
||||
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
|
||||
return
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
prompt, tools, tool_choice, True, structured_response_format
|
||||
),
|
||||
)
|
||||
try:
|
||||
|
||||
@@ -81,7 +81,6 @@ class CustomModelServer(LLM):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> BaseMessage:
|
||||
return self._execute(prompt)
|
||||
|
||||
@@ -91,6 +90,5 @@ class CustomModelServer(LLM):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
yield self._execute(prompt)
|
||||
|
||||
@@ -90,13 +90,12 @@ class LLM(abc.ABC):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> BaseMessage:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._invoke_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format, timeout_override
|
||||
prompt, tools, tool_choice, structured_response_format
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -106,7 +105,6 @@ class LLM(abc.ABC):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -116,13 +114,12 @@ class LLM(abc.ABC):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
messages = self._stream_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format, timeout_override
|
||||
prompt, tools, tool_choice, structured_response_format
|
||||
)
|
||||
|
||||
tokens = []
|
||||
@@ -141,6 +138,5 @@ class LLM(abc.ABC):
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.indexing import router as indexing_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
@@ -237,17 +238,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await close_auth_limiter()
|
||||
|
||||
|
||||
def log_http_error(request: Request, exc: Exception) -> JSONResponse:
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace
|
||||
# (almost always spammy)
|
||||
logger.debug(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code == 404 and request.url.path == "/metrics":
|
||||
# Log 404 errors for the /metrics endpoint with debug level
|
||||
logger.debug(f"404 error for /metrics endpoint: {str(exc)}")
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.warning(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
@@ -316,6 +312,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
from onyx.agents.agent_search.shared_graph_utils.constants import (
|
||||
AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
|
||||
# Standards
|
||||
SEPARATOR_LINE = "-------"
|
||||
SEPARATOR_LINE_LONG = "---------------"
|
||||
@@ -9,6 +5,8 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
|
||||
NO_RECOVERED_DOCS = "No relevant information recovered"
|
||||
YES = "yes"
|
||||
NO = "no"
|
||||
|
||||
|
||||
# Framing/Support/Template Prompts
|
||||
HISTORY_FRAMING_PROMPT = f"""
|
||||
For more context, here is the history of the conversation so far that preceded this question:
|
||||
@@ -18,43 +16,6 @@ For more context, here is the history of the conversation so far that preceded t
|
||||
""".strip()
|
||||
|
||||
|
||||
COMMON_RAG_RULES = f"""
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. \
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
|
||||
- If the information is empty or irrelevant, just say "{UNKNOWN_ANSWER}".
|
||||
|
||||
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also specify that \
|
||||
the information is not conclusive and why.
|
||||
|
||||
- When constructing/considering categories, focus less on the question and more on the context actually provided! \
|
||||
Example: if the question is about the products of company A, and the content provided lists a number of products, \
|
||||
do automatically NOT ASSUME that those belong to company A! So you cannot list those as products of company A, despite \
|
||||
the fact that the question is about company A's products. What you should say instead is maybe something like \
|
||||
"Here are a number of products, but I cannot say whether some or all of them belong to company A: \
|
||||
<proceed with listing the products>". It is ABSOLUTELY ESSENTIAL that the answer constructed reflects \
|
||||
actual knowledge. For that matter, also consider the title of the document and other information that may be \
|
||||
provided. If that does not make it clear that - in the example above - the products belong to company A, \
|
||||
then do not list them as products of company A, just maybe as "A list products that may not necessarily \
|
||||
belong to company A". THIS IS IMPORTANT!
|
||||
|
||||
- Related, if the context provides a list of items with associated data or other information that seems \
|
||||
to align with the categories in the question, but does not specify whether the items or the information is \
|
||||
specific to the exact requested category, then present the information with a disclaimer. Use a title such as \
|
||||
"I am not sure whether these items (or the information provided) is specific to [relevant category] or whether \
|
||||
these are all [specific group], but I found this information may be helpful:" \
|
||||
followed by the list of items and associated data/or information discovered.
|
||||
|
||||
- Do not group together items amongst one headline where not all items belong to the category of the headline! \
|
||||
(Example: "Products used by Company A" where some products listed are not built by Company A, but other companies,
|
||||
or it is not clear that the products are built by Company A). Only state what you know for sure!
|
||||
|
||||
- Do NOT perform any calculations in the answer! Just report on facts.
|
||||
|
||||
- If appropriate, organizing your answer in bullet points is often useful.
|
||||
""".strip()
|
||||
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT = "You are an assistant for question-answering tasks."
|
||||
|
||||
ASSISTANT_SYSTEM_PROMPT_PERSONA = f"""
|
||||
@@ -168,44 +129,20 @@ History summary:
|
||||
# Sub-question
|
||||
# Intentionally left a copy in case we want to modify this one differently
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT = f"""
|
||||
Please create a list of no more than 3 sub-questions whose answers would help to inform the answer \
|
||||
to the initial question.
|
||||
|
||||
The purpose for these sub-questions could be:
|
||||
1) decomposition to isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
Decompose the initial user question into no more than 3 appropriate sub-questions that help to answer the \
|
||||
original question. The purpose for this decomposition may be to:
|
||||
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
['what are sales for company A', 'what are sales for company B'])
|
||||
|
||||
2) clarification and/or disambiguation of ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
['what are our sales with company A','what is our market share with company A', \
|
||||
'is company A a reference customer for us', etc.])
|
||||
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you \
|
||||
are generally familiar with the entity, then you can decompose the question into sub-questions that are more \
|
||||
specific to components (i.e., 'what do we do to improve scalability of product X', 'what do we to to improve \
|
||||
scalability of product X', 'what do we do to improve stability of product X', ...])
|
||||
4) research an area that could really help to answer the question.
|
||||
|
||||
3) if a term or a metric is essentially clear, but it could relate to various aspects of an entity and you \
|
||||
are generally familiar with the entity, then you can create sub-questions that are more \
|
||||
specific (i.e., 'what do we do to improve product X' -> 'what do we do to improve scalability of product X', \
|
||||
'what do we do to improve performance of product X', 'what do we do to improve stability of product X', ...)
|
||||
|
||||
4) research individual questions and areas that should really help to ultimately answer the question.
|
||||
|
||||
Important:
|
||||
|
||||
- Each sub-question should lend itself to be answered by a RAG system. Correspondingly, phrase the question \
|
||||
in a way that is amenable to that. An example set of sub-questions based on an initial question could look like this:
|
||||
'what can I do to improve the performance of workflow X' -> \
|
||||
'what are the settings affecting performance for workflow X', 'are there complaints and bugs related to \
|
||||
workflow X performance', 'what are performance benchmarks for workflow X', ...
|
||||
|
||||
- Consequently, again, don't just decompose, but make sure that the sub-questions have the proper form. I.e., no \
|
||||
'I', etc.
|
||||
|
||||
- Do not(!) create sub-questions that are clarifying question to the person who asked the question, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
|
||||
And here is the initial question to create sub-questions for, so that you have the full context:
|
||||
Here is the initial question to decompose:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
@@ -213,79 +150,7 @@ And here is the initial question to create sub-questions for, so that you have t
|
||||
{{history}}
|
||||
|
||||
Do NOT include any text in your answer outside of the list of sub-questions!
|
||||
Please formulate your answer as a newline-separated list of questions like so (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
...
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
|
||||
# INITIAL PHASE - AWARE OF REFINEMENT
|
||||
# Sub-question
|
||||
# Suggest augmenting question generation as well, that a future refinement phase could use
|
||||
# to generate new questions
|
||||
# Intentionally left a copy in case we want to modify this one differently
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT = f"""
|
||||
Please create a list of no more than 3 sub-questions whose answers would help to inform the answer \
|
||||
to the initial question.
|
||||
|
||||
The purpose for these sub-questions could be:
|
||||
1) decomposition to isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
['what are sales for company A', 'what are sales for company B'])
|
||||
|
||||
2) clarification and/or disambiguation of ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
['what are our sales with company A','what is our market share with company A', \
|
||||
'is company A a reference customer for us', etc.])
|
||||
|
||||
3) if a term or a metric is essentially clear, but it could relate to various aspects of an entity and you \
|
||||
are generally familiar with the entity, then you can create sub-questions that are more \
|
||||
specific (i.e., 'what do we do to improve product X' -> 'what do we do to improve scalability of product X', \
|
||||
'what do we do to improve performance of product X', 'what do we do to improve stability of product X', ...)
|
||||
|
||||
4) research individual questions and areas that should really help to ultimately answer the question.
|
||||
|
||||
5) if meaningful, find relevant facts that may inform another set of sub-questions generate after the set you \
|
||||
create now are answered. Example: 'which products have we implemented at company A, and is this different to \
|
||||
its competitors?' could potentially create sub-questions 'what products have we implemented at company A', \
|
||||
and 'who are the competitors of company A'. The additional round of sub-question generation which sees the \
|
||||
answers for this initial round of sub-question creation could then use the answer to the second sub-question \
|
||||
(which could be 'company B and C are competitors of company A') to then ask 'which products have we implemented \
|
||||
at company B', 'which products have we implemented at company C'...
|
||||
|
||||
Important:
|
||||
|
||||
- Each sub-question should lend itself to be answered by a RAG system. Correspondingly, phrase the question \
|
||||
in a way that is amenable to that. An example set of sub-questions based on an initial question could look like this:
|
||||
'what can I do to improve the performance of workflow X' -> \
|
||||
'what are the settings affecting performance for workflow X', 'are there complaints and bugs related to \
|
||||
workflow X performance', 'what are performance benchmarks for workflow X', ...
|
||||
|
||||
- Consequently, again, don't just decompose, but make sure that the sub-questions have the proper form. I.e., no \
|
||||
'I', etc.
|
||||
|
||||
- Do not(!) create sub-questions that are clarifying question to the person who asked the question, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
|
||||
And here is the initial question to create sub-questions for:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
{{history}}
|
||||
|
||||
Do NOT include any text in your answer outside of the list of sub-questions!
|
||||
Please formulate your answer as a newline-separated list of questions like so (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
Please formulate your answer as a newline-separated list of questions like so:
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
@@ -297,47 +162,23 @@ Answer:
|
||||
|
||||
# TODO: combine shared pieces with INITIAL_QUESTION_DECOMPOSITION_PROMPT
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH = f"""
|
||||
Please create a list of no more than 3 sub-questions whose answers would help to inform the answer \
|
||||
to the initial question.
|
||||
|
||||
The purpose for these sub-questions could be:
|
||||
1) decomposition to isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
Decompose the initial user question into no more than 3 appropriate sub-questions that help to answer the \
|
||||
original question. The purpose for this decomposition may be to:
|
||||
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
['what are sales for company A', 'what are sales for company B'])
|
||||
|
||||
2) clarification and/or disambiguation of ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
['what are our sales with company A','what is our market share with company A', \
|
||||
'is company A a reference customer for us', etc.])
|
||||
|
||||
3) if a term or a metric is essentially clear, but it could relate to various aspects of an entity and you \
|
||||
are generally familiar with the entity, then you can create sub-questions that are more \
|
||||
specific (i.e., 'what do we do to improve product X' -> 'what do we do to improve scalability of product X', \
|
||||
'what do we do to improve performance of product X', 'what do we do to improve stability of product X', ...)
|
||||
|
||||
4) research individual questions and areas that should really help to ultimately answer the question.
|
||||
|
||||
Important:
|
||||
|
||||
- Each sub-question should lend itself to be answered by a RAG system. Correspondingly, phrase the question \
|
||||
in a way that is amenable to that. An example set of sub-questions based on an initial question could look like this:
|
||||
'what can I do to improve the performance of workflow X' -> \
|
||||
'what are the settings affecting performance for workflow X', 'are there complaints and bugs related to \
|
||||
workflow X performance', 'what are performance benchmarks for workflow X', ...
|
||||
|
||||
- Consequently, again, don't just decompose, but make sure that the sub-questions have the proper form. I.e., no \
|
||||
'I', etc.
|
||||
|
||||
- Do not(!) create sub-questions that are clarifying question to the person who asked the question, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you \
|
||||
are generally familiar with the entity, then you can decompose the question into sub-questions that are more \
|
||||
specific to components (i.e., 'what do we do to improve scalability of product X', 'what do we to to improve \
|
||||
scalability of product X', 'what do we do to improve stability of product X', ...])
|
||||
4) research an area that could really help to answer the question.
|
||||
|
||||
To give you some context, you will see below also some documents that may relate to the question. Please only \
|
||||
use this information to learn what the question is approximately asking about, but do not focus on the details \
|
||||
to construct the sub-questions! Also, some of the entities, relationships and terms that are in the dataset may \
|
||||
not be in these few documents, so DO NOT focus too much on the documents when constructing the sub-questions! \
|
||||
not be in these few documents, so DO NOT focussed too much on the documents when constructing the sub-questions! \
|
||||
Decomposition and disambiguations are most important!
|
||||
|
||||
Here are the sample docs to give you some context:
|
||||
@@ -345,7 +186,7 @@ Here are the sample docs to give you some context:
|
||||
{{sample_doc_str}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
And here is the initial question to create sub-questions for, so that you have the full context:
|
||||
And here is the initial question to decompose:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
@@ -353,9 +194,7 @@ And here is the initial question to create sub-questions for, so that you have t
|
||||
{{history}}
|
||||
|
||||
Do NOT include any text in your answer outside of the list of sub-questions!\
|
||||
Please formulate your answer as a newline-separated list of questions like so (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
Please formulate your answer as a newline-separated list of questions like so:
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
@@ -364,84 +203,6 @@ add any explanations or other text!):
|
||||
Answer:
|
||||
""".strip()
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT = f"""
|
||||
Please create a list of no more than 3 sub-questions whose answers would help to inform the answer \
|
||||
to the initial question.
|
||||
|
||||
The purpose for these sub-questions could be:
|
||||
1) decomposition to isolate individual entities (i.e., 'compare sales of company A and company B' -> \
|
||||
['what are sales for company A', 'what are sales for company B'])
|
||||
|
||||
2) clarification and/or disambiguation of ambiguous terms (i.e., 'what is our success with company A' -> \
|
||||
['what are our sales with company A','what is our market share with company A', \
|
||||
'is company A a reference customer for us', etc.])
|
||||
|
||||
3) if a term or a metric is essentially clear, but it could relate to various aspects of an entity and you \
|
||||
are generally familiar with the entity, then you can create sub-questions that are more \
|
||||
specific (i.e., 'what do we do to improve product X' -> 'what do we do to improve scalability of product X', \
|
||||
'what do we do to improve performance of product X', 'what do we do to improve stability of product X', ...)
|
||||
|
||||
4) research individual questions and areas that should really help to ultimately answer the question.
|
||||
|
||||
5) if applicable and useful, consider using sub-questions to gather relevant information that can inform a \
|
||||
subsequent set of sub-questions. The answers to your initial sub-questions will be available when generating \
|
||||
the next set.
|
||||
For example, if you start with the question, "Which products have we implemented at Company A, and how does \
|
||||
this compare to its competitors?" you might first create sub-questions like "What products have we implemented \
|
||||
at Company A?" and "Who are the competitors of Company A?"
|
||||
The answer to the second sub-question, such as "Company B and C are competitors of Company A," can then be used \
|
||||
to generate more specific sub-questions in the next round, like "Which products have we implemented at Company B?" \
|
||||
and "Which products have we implemented at Company C?"
|
||||
|
||||
You'll be the judge!
|
||||
|
||||
Important:
|
||||
|
||||
- Each sub-question should lend itself to be answered by a RAG system. Correspondingly, phrase the question \
|
||||
in a way that is amenable to that. An example set of sub-questions based on an initial question could look like this:
|
||||
'what can I do to improve the performance of workflow X' -> \
|
||||
'what are the settings affecting performance for workflow X', 'are there complaints and bugs related to \
|
||||
workflow X performance', 'what are performance benchmarks for workflow X', ...
|
||||
|
||||
- Consequently, again, don't just decompose, but make sure that the sub-questions have the proper form. I.e., no \
|
||||
'I', etc.
|
||||
|
||||
- Do not(!) create sub-questions that are clarifying question to the person who asked the question, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
To give you some context, you will see below also some documents that may relate to the question. Please only \
|
||||
use this information to learn what the question is approximately asking about, but do not focus on the details \
|
||||
to construct the sub-questions! Also, some of the entities, relationships and terms that are in the dataset may \
|
||||
not be in these few documents, so DO NOT focus too much on the documents when constructing the sub-questions! \
|
||||
Decomposition and disambiguations are most important!
|
||||
|
||||
Here are the sample docs to give you some context:
|
||||
{SEPARATOR_LINE}
|
||||
{{sample_doc_str}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
And here is the initial question to create sub-questions for, so that you have the full context:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
{{history}}
|
||||
|
||||
Do NOT include any text in your answer outside of the list of sub-questions!\
|
||||
Please formulate your answer as a newline-separated list of questions like so (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
<sub-question>
|
||||
...
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
|
||||
# Retrieval
|
||||
QUERY_REWRITING_PROMPT = f"""
|
||||
@@ -496,35 +257,23 @@ Answer:
|
||||
""".strip()
|
||||
|
||||
|
||||
# Sub-Question Answer Generation
|
||||
# Sub-Question Anser Generation
|
||||
SUB_QUESTION_RAG_PROMPT = f"""
|
||||
Use the context provided below - and only the provided context - to answer the given question. \
|
||||
(Note that the answer is in service of answering a broader question, given below as 'motivation').
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns the ultimate goal. \
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the \
|
||||
question based on the context, say "{UNKNOWN_ANSWER}". It is a matter of life and death that you do NOT \
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal. \
|
||||
(But keep other details as well.)
|
||||
|
||||
{COMMON_RAG_RULES}
|
||||
|
||||
- Make sure that you only state what you actually can positively learn from the provided context! Particularly \
|
||||
don't make assumptions! Example: if i) a question you should answer is asking for products of companies that \
|
||||
are competitors of company A, and ii) the context mentions products of companies A, B, C, D, E, etc., do NOT assume \
|
||||
that B, C, D, E, etc. are competitors of A! All you know is that these are products of a number of companies, and you \
|
||||
would have to rely on another question - that you do not have access to - to learn which companies are competitors of A.
|
||||
Correspondingly, you should not say that these are the products of competitors of A, but rather something like \
|
||||
"Here are some products of various companies".
|
||||
|
||||
It is critical that you provide inline citations in the format [D1], [D2], [D3], etc! Please use format [D1][D2] and NOT \
|
||||
[D1, D2] format if you cite two or more documents together! \
|
||||
It is critical that you provide inline citations in the format [D1], [D2], [D3], etc! \
|
||||
It is important that the citation is close to the information it supports. \
|
||||
Proper citations are very important to the user!
|
||||
|
||||
Here is the document context for you to consider:
|
||||
{SEPARATOR_LINE}
|
||||
{{context}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
For your general information, here is the ultimate motivation for the question you need to answer:
|
||||
For your general information, here is the ultimate motivation:
|
||||
{SEPARATOR_LINE}
|
||||
{{original_question}}
|
||||
{SEPARATOR_LINE}
|
||||
@@ -534,8 +283,12 @@ And here is the actual question I want you to answer based on the context above
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data. (Again, only state what you see in the documents \
|
||||
for sure and communicate if/in what way this may or may not relate to the question you need to answer!)
|
||||
Here is the context:
|
||||
{SEPARATOR_LINE}
|
||||
{{context}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data.
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
@@ -568,18 +321,22 @@ Use the information provided below - and only the provided information - to answ
|
||||
|
||||
The information provided below consists of:
|
||||
1) a number of answered sub-questions - these are very important to help you organize your thoughts and your answer
|
||||
2) a number of documents that are deemed relevant for the question.
|
||||
2) a number of documents that deemed relevant for the question.
|
||||
|
||||
{{history}}
|
||||
|
||||
It is critical that you provide proper inline citations to documents in the format [D1], [D2], [D3], etc.! \
|
||||
It is critical that you provide prover inline citations to documents in the format [D1], [D2], [D3], etc.! \
|
||||
It is important that the citation is close to the information it supports. If you have multiple citations that support \
|
||||
a fact, please cite for example as [D1][D3], or [D2][D4], etc. \
|
||||
Feel free to also cite sub-questions in addition to documents, but make sure that you have documents cited with the \
|
||||
sub-question citation. If you want to cite both a document and a sub-question, please use [D1][Q3], or [D2][D7][Q4], etc. \
|
||||
Again, please NEVER cite sub-questions without a document citation! Proper citations are very important for the user!
|
||||
|
||||
{COMMON_RAG_RULES}
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. \
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "{UNKNOWN_ANSWER}".
|
||||
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
@@ -604,9 +361,7 @@ And here is the question I want you to answer based on the information above:
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data. (Again, only state what you see in the documents for \
|
||||
sure and communicate if/in what way this may or may not relate to the question you need to answer! Use the answered \
|
||||
sub-questions as well, but be cautious and reconsider the docments again for validation.)
|
||||
Please keep your answer brief and concise, and focus on facts and data.
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
@@ -621,7 +376,11 @@ The information provided below consists of a number of documents that were deeme
|
||||
|
||||
{{history}}
|
||||
|
||||
{COMMON_RAG_RULES}
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. \
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is irrelevant, just say "{UNKNOWN_ANSWER}".
|
||||
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
@@ -640,8 +399,7 @@ And here is the question I want you to answer based on the context above:
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data. (Again, only state what you see in the documents \
|
||||
for sure and communicate if/in what way this may or may not relate to the question you need to answer!)
|
||||
Please keep your answer brief and concise, and focus on facts and data.
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
@@ -681,12 +439,6 @@ independently without the original question available
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant documents from a document store.
|
||||
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not answerable \
|
||||
with the available context, and you should not ask similar questions.
|
||||
- Do not(!) create sub-questions that are clarifying question to the person who asked the question, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
Here is the initial question:
|
||||
{SEPARATOR_LINE}
|
||||
@@ -722,111 +474,7 @@ objects/relationships/terms you can ask about! Do not ask about entities, terms
|
||||
Again, please find questions that are NOT overlapping too much with the already answered sub-questions or those that \
|
||||
already were suggested and failed. In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of questions separated by one new line like this (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
<sub-question 1>
|
||||
<sub-question 2>
|
||||
<sub-question 3>
|
||||
...""".strip()
|
||||
|
||||
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS = f"""
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite good enough. \
|
||||
Also, some sub-questions had been answered and this information has been used to provide the initial answer. \
|
||||
Some other subquestions may have been suggested based on little knowledge, but they were not directly answerable. \
|
||||
Also, some entities, relationships and terms are given to you so that you have an idea of how the available data looks like.
|
||||
|
||||
Your role is to generate 2-4 new sub-questions that would help to answer the initial question, considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered AND their answers
|
||||
4) The sub-questions that were suggested but not answered (and that you should not repeat!)
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system. So a good idea would be to use the sub-questions to \
|
||||
resolve ambiguities and/or to separate the question for different entities that may be involved in the original question, \
|
||||
but in a way that does not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
|
||||
- The new sub-questions should be specific to the question and provide richer context for the question, resolve ambiguities, \
|
||||
or address shortcoming of the initial answer
|
||||
|
||||
- Each new sub-question - when answered - should be relevant for the answer to the original question
|
||||
|
||||
- The new sub-questions should be free from comparisons, ambiguities,judgements, aggregations, or any other complications that \
|
||||
may require extra context
|
||||
|
||||
- The new sub-questions MUST have the full context of the original question so that it can be executed by a RAG system \
|
||||
independently without the original question available
|
||||
Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
|
||||
- For each new sub-question, please also provide a search term that can be used to retrieve relevant documents \
|
||||
from a document store.
|
||||
|
||||
- Consider specifically the sub-questions that were suggested but not answered. This is a sign that they are not answerable \
|
||||
with the available context, and you should not ask similar questions.
|
||||
|
||||
- Pay attention to the answers of previous sub-question to make your sub-questions more specific! \
|
||||
Often the initial sub-questions were set up to give you critical information that you should use to generate new sub-questions.\
|
||||
For example, if the answer to a an earlier sub-question is \
|
||||
'Company B and C are competitors of Company A', you should not ask now a new sub-question involving the term 'competitors', \
|
||||
as you already have the information to create a more precise question - you should instead explicitly reference \
|
||||
'Company B' and 'Company C' in your new sub-questions, as these are the competitors based on the previously answered question.
|
||||
|
||||
- Be precise(!) and don't make inferences you cannot be sure about! For example, in the previous example \
|
||||
where Company B and Company C were identified as competitors of Company A, and then you also get information on \
|
||||
companies D and E, do not make the inference that these are also competitors of Company A! Stick to the information you have!
|
||||
(Also, don't assume that companies B and C arethe only competitors of A, unless stated!)
|
||||
|
||||
- Do not(!) create sub-questions that are clarifying question *to the person who asked the question*, \
|
||||
like making suggestions or asking the user for more information! This is not useful for the actual \
|
||||
question-answering process! You need to take the information from the user as it is given to you! \
|
||||
For example, should the question be of the type 'why does product X perform poorly for customer A', DO NOT create a \
|
||||
sub-question of the type 'what are the settings that customer A uses for product X?'! A valid sub-question \
|
||||
could rather be 'which settings for product X have been shown to lead to poor performance for customers?'
|
||||
|
||||
Here is the initial question:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
{{history}}
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
{SEPARATOR_LINE}
|
||||
{{base_answer}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
{SEPARATOR_LINE}
|
||||
{{answered_subquestions_with_answers}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
{SEPARATOR_LINE}
|
||||
{{failed_sub_questions}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
{SEPARATOR_LINE}
|
||||
{{entity_term_extraction_str}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the main question. \
|
||||
Specifically pay attention also to the entities, relationships and terms extracted, as these indicate what type of \
|
||||
objects/relationships/terms you can ask about! Do not ask about entities, terms or relationships that are not mentioned \
|
||||
in the 'entities, relationships and terms' section.
|
||||
|
||||
Again, please find questions that are NOT overlapping too much with the already answered sub-questions or those that \
|
||||
already were suggested and failed. In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of questions separated by one new line like this (and please ONLY ANSWER WITH THIS LIST! Do not \
|
||||
add any explanations or other text!):
|
||||
|
||||
Generate the list of questions separated by one new line like this:
|
||||
<sub-question 1>
|
||||
<sub-question 2>
|
||||
<sub-question 3>
|
||||
@@ -841,7 +489,7 @@ Your task is to improve on a given answer to a question, as the initial answer w
|
||||
Use the information provided below - and only the provided information - to write your new and improved answer.
|
||||
|
||||
The information provided below consists of:
|
||||
1) an initial answer that was given but likely found to be lacking in some way.
|
||||
1) an initial answer that was given but found to be lacking in some way.
|
||||
2) a number of answered sub-questions - these are very important(!) and definitely should help you to answer the main \
|
||||
question. Note that the sub-questions have a type, 'initial' and 'refined'. The 'initial' ones were available for the \
|
||||
creation of the initial answer, but the 'refined' were not, they are new. So please use the 'refined' sub-questions in \
|
||||
@@ -851,7 +499,6 @@ particular to update/extend/correct/enrich the initial answer and to add more de
|
||||
the relevant document for a fact!
|
||||
|
||||
It is critical that you provide proper inline citations to documents in the format [D1], [D2], [D3], etc! \
|
||||
Please use format [D1][D2] and NOT [D1, D2] format if you cite two or more documents together! \
|
||||
It is important that the citation is close to the information it supports. \
|
||||
DO NOT just list all of the citations at the very end. \
|
||||
Feel free to also cite sub-questions in addition to documents, \
|
||||
@@ -862,7 +509,14 @@ Proper citations are very important for the user!
|
||||
|
||||
{{history}}
|
||||
|
||||
{COMMON_RAG_RULES}
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. \
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "{UNKNOWN_ANSWER}".
|
||||
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also specify that \
|
||||
the information is not conclusive and why.
|
||||
- Ignore any existing citations within the answered sub-questions, like [D1]... and [Q2]! The citations you will need to \
|
||||
use will need to refer to the documents (and sub-questions) that you are explicitly presented with below!
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
@@ -891,9 +545,7 @@ Lastly, here is the main question I want you to answer based on the information
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data. (Again, only state what you see in the documents for \
|
||||
sure and communicate if/in what way this may or may not relate to the question you need to answer! Use the answered \
|
||||
sub-questions as well, but be cautious and reconsider the docments again for validation.)
|
||||
Please keep your answer brief and concise, and focus on facts and data.
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
@@ -909,13 +561,18 @@ The information provided below consists of:
|
||||
2) a number of documents that were also deemed relevant for the question.
|
||||
|
||||
It is critical that you provide proper inline citations to documents in the format [D1], [D2], [D3], etc! \
|
||||
Please use format [D1][D2] and NOT [D1, D2] format if you cite two or more documents together! \
|
||||
It is important that the citation is close to the information it supports. \
|
||||
DO NOT just list all of the citations at the very end of your response. Citations are very important for the user!
|
||||
|
||||
{{history}}
|
||||
|
||||
{COMMON_RAG_RULES}
|
||||
IMPORTANT RULES:
|
||||
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer. \
|
||||
You may give some additional facts you learned, but do not try to invent an answer.
|
||||
- If the information is empty or irrelevant, just say "{UNKNOWN_ANSWER}".
|
||||
- If the information is relevant but not fully conclusive, provide an answer to the extent you can but also specify that \
|
||||
the information is not conclusive and why.
|
||||
|
||||
Again, you should be sure that the answer is supported by the information provided!
|
||||
|
||||
Try to keep your answer concise. But also highlight uncertainties you may have should there be substantial ones, \
|
||||
@@ -940,103 +597,11 @@ Lastly, here is the question I want you to answer based on the information above
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Please keep your answer brief and concise, and focus on facts and data. (Again, only state what you see in the documents for \
|
||||
sure and communicate if/in what way this may or may not relate to the question you need to answer!)
|
||||
Please keep your answer brief and concise, and focus on facts and data.
|
||||
|
||||
Answer:
|
||||
""".strip()
|
||||
|
||||
REFINED_ANSWER_VALIDATION_PROMPT = f"""
|
||||
{{persona_specification}}
|
||||
|
||||
Your task is to verify whether a given answer is truthful and accurate, and supported by the facts that you \
|
||||
will be provided with.
|
||||
|
||||
The information provided below consists of:
|
||||
|
||||
1) a question that needed to be answered
|
||||
|
||||
2) a proposed answer to the question, whose accuracy you should assess
|
||||
|
||||
3) potentially, a brief summary of the history of the conversation thus far, as it may give more context \
|
||||
to the question. Note that the statements in the history are NOT considered as facts, ONLY but serve to to \
|
||||
give context to the question.
|
||||
|
||||
4) a number of answered sub-questions - you can take the answers as facts for these purposes.
|
||||
|
||||
5) a number of relevant documents that should support the answer and that you should use as fact, \
|
||||
i.e., if a statement in the document backs up a statement in the answer, then that statement in the answer \
|
||||
should be considered as true.
|
||||
|
||||
|
||||
IMPORTANT RULES AND CONSIDERATIONS:
|
||||
|
||||
- Please consider the statements made in the proposed answer and assess whether they are truthful and accurate, based \
|
||||
on the provided sub-answered and the documents. (Again, the history is NOT considered as facts!)
|
||||
|
||||
- Look in particular for:
|
||||
* material statements that are not supported by the sub-answered or the documents
|
||||
* assignments and groupings that are not supported, like company A is competitor of company B, but this is not \
|
||||
explicitly supported by documents or sub-answers, guesses or interpretations unless explicitly asked for
|
||||
|
||||
- look also at the citations in the proposed answer and assess whether they are appropriate given the statements \
|
||||
made in the proposed answer that cites the document.
|
||||
|
||||
- Are items grouped together amongst one headline where not all items belong to the category of the headline? \
|
||||
(Example: "Products used by Company A" where some products listed are not used by Company A)
|
||||
|
||||
- Does the proposed answer address the question in full?
|
||||
|
||||
- Is the answer specific to the question? Example: if the question asks for the prices for products by Company A, \
|
||||
but the answer lists the prices for products by Company A and Company B, or products it cannot be sure are by \
|
||||
Company A, then this is not quite specific enough to the question and the answer should be rejected.
|
||||
|
||||
- Similarly, if the question asks for properties of a certain class but the proposed answer lists or includes entities \
|
||||
that are not of that class without very explicitly saying so, then the answer should be considered inaccurate.
|
||||
|
||||
- If there are any calculations in the proposed answer that are not supported by the documents, they need to be tested. \
|
||||
If any calculation is wrong, the proposed answer should be considered as not trustworthy.
|
||||
|
||||
|
||||
Here is the information:
|
||||
{SEPARATOR_LINE_LONG}
|
||||
|
||||
QUESTION:
|
||||
{SEPARATOR_LINE}
|
||||
{{question}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
PROPOSED ANSWER:
|
||||
{SEPARATOR_LINE}
|
||||
{{proposed_answer}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
Here is the additional contextual information:
|
||||
{SEPARATOR_LINE_LONG}
|
||||
|
||||
{{history}}
|
||||
|
||||
Sub-questions and their answers (to be considered as facts):
|
||||
{SEPARATOR_LINE}
|
||||
{{answered_sub_questions}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
And here are the relevant documents that support the sub-question answers, and that are relevant for the actual question:
|
||||
{SEPARATOR_LINE}
|
||||
{{relevant_docs}}
|
||||
{SEPARATOR_LINE}
|
||||
|
||||
|
||||
Please think through this step by step. Format your response just as a string in the following format:
|
||||
|
||||
Analysis: <think through your reasoning as outlined in the 'IMPORTANT RULES AND CONSIDERATIONS' section above, \
|
||||
but keep it short. Come to a conclusion whether the proposed answer can be trusted>
|
||||
Comments: <state your condensed comments you would give to a user reading the proposed answer, regarding the accuracy and \
|
||||
specificity.>
|
||||
{AGENT_ANSWER_SEPARATOR} <answer here only with yes or no, whether the proposed answer can be trusted. Base this on your \
|
||||
analysis, but only say 'yes' (trustworthy) or 'no' (not trustworthy)>
|
||||
""".strip()
|
||||
|
||||
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT = f"""
|
||||
For the given question, please compare the initial answer and the refined answer and determine if the refined answer is \
|
||||
|
||||
@@ -61,10 +61,10 @@ def _create_indexable_chunks(
|
||||
doc_updated_at=None,
|
||||
primary_owners=[],
|
||||
secondary_owners=[],
|
||||
chunk_count=preprocessed_doc["chunk_ind"] + 1,
|
||||
chunk_count=1,
|
||||
)
|
||||
|
||||
ids_to_documents[document.id] = document
|
||||
if preprocessed_doc["chunk_ind"] == 0:
|
||||
ids_to_documents[document.id] = document
|
||||
|
||||
chunk = DocMetadataAwareIndexChunk(
|
||||
chunk_id=preprocessed_doc["chunk_ind"],
|
||||
@@ -92,7 +92,6 @@ def _create_indexable_chunks(
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
return list(ids_to_documents.values()), chunks
|
||||
@@ -193,7 +192,6 @@ def seed_initial_documents(
|
||||
last_successful_index_time=last_index_time,
|
||||
seeding_flow=True,
|
||||
)
|
||||
|
||||
cc_pair_id = cast(int, result.data)
|
||||
processed_docs = fetch_versioned_implementation(
|
||||
"onyx.seeding.load_docs",
|
||||
@@ -251,5 +249,4 @@ def seed_initial_documents(
|
||||
.values(chunk_count=doc.chunk_count)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
@@ -40,9 +39,7 @@ from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import count_index_attempts_for_connector
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from onyx.db.models import SearchSettings
|
||||
@@ -549,47 +546,6 @@ def get_docs_sync_status(
|
||||
return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/errors")
|
||||
def get_cc_pair_indexing_errors(
|
||||
cc_pair_id: int,
|
||||
include_resolved: bool = Query(False),
|
||||
page: int = Query(0, ge=0),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
|
||||
"""Gives back all errors for a given CC Pair. Allows pagination based on page and page_size params.
|
||||
|
||||
Args:
|
||||
cc_pair_id: ID of the connector-credential pair to get errors for
|
||||
include_resolved: Whether to include resolved errors in the results
|
||||
page: Page number for pagination, starting at 0
|
||||
page_size: Number of errors to return per page
|
||||
_: Current user, must be curator or admin
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Paginated list of indexing errors for the CC pair.
|
||||
"""
|
||||
total_count = count_index_attempt_errors_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
unresolved_only=not include_resolved,
|
||||
)
|
||||
|
||||
index_attempt_errors = get_index_attempt_errors_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
unresolved_only=not include_resolved,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedReturn(
|
||||
items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors],
|
||||
total_items=total_count,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
@@ -614,16 +613,6 @@ def get_connector_indexing_status(
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
indexing_statuses: list[ConnectorIndexingStatus] = []
|
||||
|
||||
if MOCK_CONNECTOR_FILE_PATH:
|
||||
import json
|
||||
|
||||
with open(MOCK_CONNECTOR_FILE_PATH, "r") as f:
|
||||
raw_data = json.load(f)
|
||||
connector_indexing_statuses = [
|
||||
ConnectorIndexingStatus(**status) for status in raw_data
|
||||
]
|
||||
return connector_indexing_statuses
|
||||
|
||||
# NOTE: If the connector is deleting behind the scenes,
|
||||
# accessing cc_pairs can be inconsistent and members like
|
||||
# connector or credential may be None.
|
||||
|
||||
23
backend/onyx/server/documents/indexing.py
Normal file
23
backend/onyx/server/documents/indexing.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.index_attempt import (
|
||||
get_index_attempt_errors,
|
||||
)
|
||||
from onyx.db.models import User
|
||||
from onyx.server.documents.models import IndexAttemptError
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@router.get("/admin/indexing-errors/{index_attempt_id}")
|
||||
def get_indexing_errors(
|
||||
index_attempt_id: int,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[IndexAttemptError]:
|
||||
indexing_errors = get_index_attempt_errors(index_attempt_id, db_session)
|
||||
return [IndexAttemptError.from_db_model(e) for e in indexing_errors]
|
||||
@@ -8,9 +8,9 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import DocumentErrorSummary
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
@@ -149,7 +150,6 @@ class CredentialSnapshot(CredentialBase):
|
||||
class IndexAttemptSnapshot(BaseModel):
|
||||
id: int
|
||||
status: IndexingStatus | None
|
||||
from_beginning: bool
|
||||
new_docs_indexed: int # only includes completely new docs
|
||||
total_docs_indexed: int # includes docs that are updated
|
||||
docs_removed_from_index: int
|
||||
@@ -166,7 +166,6 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
return IndexAttemptSnapshot(
|
||||
id=index_attempt.id,
|
||||
status=index_attempt.status,
|
||||
from_beginning=index_attempt.from_beginning,
|
||||
new_docs_indexed=index_attempt.new_docs_indexed or 0,
|
||||
total_docs_indexed=index_attempt.total_docs_indexed or 0,
|
||||
docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
|
||||
@@ -182,6 +181,31 @@ class IndexAttemptSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class IndexAttemptError(BaseModel):
|
||||
id: int
|
||||
index_attempt_id: int | None
|
||||
batch_number: int | None
|
||||
doc_summaries: list[DocumentErrorSummary]
|
||||
error_msg: str | None
|
||||
traceback: str | None
|
||||
time_created: str
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
|
||||
doc_summaries = [
|
||||
DocumentErrorSummary.from_dict(summary) for summary in error.doc_summaries
|
||||
]
|
||||
return IndexAttemptError(
|
||||
id=error.id,
|
||||
index_attempt_id=error.index_attempt_id,
|
||||
batch_number=error.batch,
|
||||
doc_summaries=doc_summaries,
|
||||
error_msg=error.error_msg,
|
||||
traceback=error.traceback,
|
||||
time_created=error.time_created.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
# These are the types currently supported by the pagination hook
|
||||
# More api endpoints can be refactored and be added here for use with the pagination hook
|
||||
PaginatedType = TypeVar(
|
||||
@@ -190,7 +214,6 @@ PaginatedType = TypeVar(
|
||||
FullUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
ChatSessionMinimal,
|
||||
IndexAttemptErrorPydantic,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -213,6 +213,8 @@ def get_chat_session(
|
||||
# we need the tool call objs anyways, so just fetch them in a single call
|
||||
prefetch_tool_calls=True,
|
||||
)
|
||||
for message in session_messages:
|
||||
translate_db_message_to_chat_message_detail(message)
|
||||
|
||||
return ChatSessionDetailResponse(
|
||||
chat_session_id=session_id,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -18,10 +17,6 @@ def load_settings() -> Settings:
|
||||
settings = (
|
||||
Settings.model_validate(stored_settings) if stored_settings else Settings()
|
||||
)
|
||||
except KvKeyNotFoundError:
|
||||
# Default to empty settings if no settings have been set yet
|
||||
logger.debug(f"No settings found in KV store for key: {KV_SETTINGS_KEY}")
|
||||
settings = Settings()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading settings from KV store: {str(e)}")
|
||||
settings = Settings()
|
||||
|
||||
@@ -58,7 +58,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
|
||||
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
||||
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||
SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
class SearchResponseSummary(SearchQueryInfo):
|
||||
@@ -180,12 +179,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
QUERY_FIELD: {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "What to search for",
|
||||
},
|
||||
},
|
||||
"required": [QUERY_FIELD],
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -224,7 +223,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
rephrased_query = history_based_query_rephrase(
|
||||
query=query, history=history, llm=llm
|
||||
)
|
||||
return {QUERY_FIELD: rephrased_query}
|
||||
return {"query": rephrased_query}
|
||||
|
||||
"""Actual tool execution"""
|
||||
|
||||
@@ -280,7 +279,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
def run(
|
||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, llm_kwargs[QUERY_FIELD])
|
||||
query = cast(str, llm_kwargs["query"])
|
||||
force_no_rerank = False
|
||||
alternate_db_session = None
|
||||
retrieved_sections_callback = None
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import sys
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T", dict, list, tuple, set, frozenset)
|
||||
|
||||
|
||||
def deep_getsizeof(obj: T, seen: set[int] | None = None) -> int:
|
||||
"""Recursively sum size of objects, handling circular references."""
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return 0 # Prevent infinite recursion for circular references
|
||||
|
||||
seen.add(obj_id)
|
||||
size = sys.getsizeof(obj)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
size += sum(
|
||||
deep_getsizeof(k, seen) + deep_getsizeof(v, seen) for k, v in obj.items()
|
||||
)
|
||||
elif isinstance(obj, (list, tuple, set, frozenset)):
|
||||
size += sum(deep_getsizeof(i, seen) for i in obj)
|
||||
|
||||
return size
|
||||
@@ -1,4 +1,3 @@
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import as_completed
|
||||
@@ -14,10 +13,6 @@ logger = setup_logger()
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: list[tuple[Callable, tuple]],
|
||||
allow_failures: bool = False,
|
||||
@@ -83,10 +78,6 @@ class FunctionCall(Generic[R]):
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_functions_in_parallel(
|
||||
function_calls: list[FunctionCall],
|
||||
allow_failures: bool = False,
|
||||
@@ -118,49 +109,3 @@ def run_functions_in_parallel(
|
||||
raise
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
):
|
||||
super().__init__()
|
||||
self.timeout = timeout
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.exception: Exception | None = None
|
||||
|
||||
def run(self) -> None:
|
||||
try:
|
||||
self.result = self.func(*self.args, **self.kwargs)
|
||||
except Exception as e:
|
||||
self.exception = e
|
||||
|
||||
def end(self) -> None:
|
||||
raise TimeoutError(
|
||||
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
|
||||
)
|
||||
|
||||
|
||||
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
|
||||
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
|
||||
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
|
||||
# is not safe, update this comment.
|
||||
def run_with_timeout(
|
||||
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
"""
|
||||
Executes a function with a timeout. If the function doesn't complete within the specified
|
||||
timeout, raises TimeoutError.
|
||||
"""
|
||||
task = TimeoutThread(timeout, func, *args, **kwargs)
|
||||
task.start()
|
||||
task.join(timeout)
|
||||
|
||||
if task.exception is not None:
|
||||
raise task.exception
|
||||
if task.is_alive():
|
||||
task.end()
|
||||
|
||||
return task.result
|
||||
|
||||
@@ -34,8 +34,8 @@ langchain-core==0.3.24
|
||||
langchain-openai==0.2.9
|
||||
langchain-text-splitters==0.3.2
|
||||
langchainhub==0.1.21
|
||||
langgraph==0.2.72
|
||||
langgraph-checkpoint==2.0.13
|
||||
langgraph==0.2.59
|
||||
langgraph-checkpoint==2.0.5
|
||||
langgraph-sdk==0.1.44
|
||||
litellm==1.60.2
|
||||
lxml==5.3.0
|
||||
|
||||
@@ -42,7 +42,7 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
|
||||
@@ -33,7 +33,7 @@ stopasgroup=true
|
||||
command=celery -A onyx.background.celery.versioned_apps.light worker
|
||||
--loglevel=INFO
|
||||
--hostname=light@%%n
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
|
||||
stdout_logfile=/var/log/celery_worker_light.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.gitbook.connector import GitbookConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gitbook_connector() -> GitbookConnector:
|
||||
connector = GitbookConnector(
|
||||
space_id=os.environ["GITBOOK_SPACE_ID"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"gitbook_api_key": os.environ["GITBOOK_API_KEY"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
|
||||
doc_batch_generator = gitbook_connector.load_from_state()
|
||||
|
||||
# Get first batch of documents
|
||||
doc_batch = next(doc_batch_generator)
|
||||
assert len(doc_batch) > 0
|
||||
|
||||
# Verify first document structure
|
||||
doc = doc_batch[0]
|
||||
|
||||
# Basic document properties
|
||||
assert doc.id.startswith("gitbook-")
|
||||
assert doc.semantic_identifier == "Acme Corp Internal Handbook"
|
||||
assert doc.source == DocumentSource.GITBOOK
|
||||
|
||||
# Metadata checks
|
||||
assert "path" in doc.metadata
|
||||
assert "type" in doc.metadata
|
||||
assert "kind" in doc.metadata
|
||||
|
||||
# Section checks
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
|
||||
# Content specific checks
|
||||
content = section.text
|
||||
|
||||
# Check for specific content elements
|
||||
assert "* Fruit Shopping List:" in content
|
||||
assert "> test quote it doesn't mean anything" in content
|
||||
|
||||
# Check headings
|
||||
assert "# Heading 1" in content
|
||||
assert "## Heading 2" in content
|
||||
assert "### Heading 3" in content
|
||||
|
||||
# Check task list
|
||||
assert "- [ ] Uncompleted Task" in content
|
||||
assert "- [x] Completed Task" in content
|
||||
|
||||
# Check table content
|
||||
assert "| ethereum | 10 | 3000 |" in content
|
||||
assert "| bitcoin | 2 | 98000 |" in content
|
||||
|
||||
# Check paragraph content
|
||||
assert "New York City comprises 5 boroughs" in content
|
||||
assert "Empire State Building" in content
|
||||
|
||||
# Check code block (just verify presence of some unique code elements)
|
||||
assert "function fizzBuzz(n)" in content
|
||||
assert 'res.push("FizzBuzz")' in content
|
||||
|
||||
assert section.link # Should have a URL
|
||||
|
||||
# Time-based polling test
|
||||
current_time = time.time()
|
||||
poll_docs = gitbook_connector.poll_source(0, current_time)
|
||||
poll_batch = next(poll_docs)
|
||||
assert len(poll_batch) > 0
|
||||
@@ -1,15 +1,14 @@
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def test_create_chat_session_and_send_messages() -> None:
|
||||
def test_create_chat_session_and_send_messages(db_session: Session) -> None:
|
||||
# Create a test user
|
||||
with get_session_context_manager() as db_session:
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
test_user = User(email="test@example.com", hashed_password="dummy_hash")
|
||||
db_session.add(test_user)
|
||||
db_session.commit()
|
||||
|
||||
base_url = "http://localhost:8080" # Adjust this to your API's base URL
|
||||
headers = {"Authorization": f"Bearer {test_user.id}"}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user