Compare commits

..

15 Commits

Author SHA1 Message Date
Weves
9b169350a9 Switch to monotonic 2025-03-07 19:04:47 -08:00
Weves
c1dbb073d0 Small tweaks 2025-03-07 15:53:22 -08:00
Weves
39bfc6ae16 Add basic memory logging 2025-03-07 15:46:14 -08:00
rkuo-danswer
9217243e3e Bugfix/query history notes (#4204)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

* slight improvements and notes to query history and seeding

* update test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 19:52:30 +00:00
rkuo-danswer
61ccba82a9 light worker needs to discover some indexing tasks (#4209)
* light worker needs to discover some indexing tasks

* fix formatting

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 11:52:09 -08:00
Weves
9e8eba23c3 Fix frozen model issue 2025-03-07 09:05:43 -08:00
evan-danswer
0c29743538 use max_tokens to do better rate limit handling (#4224)
* use max_tokens to do better rate limit handling

* fix unti tests

* address greptile comment, thanks greptile
2025-03-06 18:12:05 -08:00
pablonyx
08b2421947 fix 2025-03-06 17:30:31 -08:00
pablonyx
ed518563db minor typing update 2025-03-06 17:02:39 -08:00
pablonyx
a32f7dc936 Fix Connector tests (confluence) (#4221) 2025-03-06 17:00:01 -08:00
rkuo-danswer
798e10c52f revert to always building model server (#4213)
* revert to always building model server

* fix just in case

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 23:49:45 +00:00
pablonyx
bf4983e35a Ensure consistent UX (#4222)
* ux consistent

* nit

* Update web/src/app/admin/configuration/llm/interfaces.ts

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-06 23:13:32 +00:00
evan-danswer
b7da91e3ae improved basic search latency (#4186)
* improved basic search latency

* address PR comments + minor cleanup
2025-03-06 22:22:59 +00:00
Weves
29382656fc Stop trying a million times for the user validity check 2025-03-06 15:35:49 -08:00
pablonyx
7d6db8d500 Comma separated list for Github repos (#4199) 2025-03-06 14:46:57 -08:00
63 changed files with 813 additions and 256 deletions

View File

@@ -12,29 +12,40 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depend on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
changed: "true"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -1,6 +1,7 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -51,7 +52,7 @@ env:
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -76,7 +77,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -27,6 +27,8 @@ def get_empty_chat_messages_entries__paginated(
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],

View File

@@ -48,10 +48,15 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow at a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -246,6 +251,8 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

View File

@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -92,6 +93,7 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -119,6 +120,7 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -153,8 +155,9 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -278,6 +281,9 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -141,6 +142,7 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -112,6 +113,7 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,13 +50,7 @@ def decide_refinement_need(
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)

View File

@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -96,6 +97,7 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -179,8 +182,9 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -302,7 +306,11 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -409,6 +417,7 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
assert query_info is not None, "must have query info"
return query_info

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -96,6 +97,7 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,8 +56,9 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,6 +94,7 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,7 +44,9 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -25,6 +34,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -47,7 +82,6 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -153,10 +198,22 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)
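
The change above overlaps query embedding and keyword analysis with tool selection by starting them on background threads and joining them only once the search tool is actually chosen. A minimal sketch of the same overlap pattern in isolation (expensive_lookup is a placeholder, not part of this change):

from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background

def expensive_lookup(query: str) -> list[str]:
    # stand-in for get_query_embedding / query_analysis
    return query.lower().split()

# start the work without blocking the main flow
thread = run_in_background(expensive_lookup, "reset my password")

# ... choose a tool, build prompts, etc. ...

# block only at the point where the result is actually needed
keywords = wait_on_background(thread)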

View File

@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -402,6 +404,7 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -505,3 +508,9 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)
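
A quick behavioral check of the new helper, using a minimal stand-in for LLMConfig since only model_provider and model_name are read (the stand-in dataclass is illustrative, not part of this change):

from dataclasses import dataclass

from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens

@dataclass
class _FakeLLMConfig:
    # only the two attributes the helper inspects
    model_provider: str
    model_name: str

assert _should_restrict_tokens(_FakeLLMConfig("openai", "gpt-4o")) is True
assert _should_restrict_tokens(_FakeLLMConfig("openai", "o1")) is False  # o-series gets no cap
assert _should_restrict_tokens(_FakeLLMConfig("anthropic", "claude-3-5-sonnet-latest")) is True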

View File

@@ -111,5 +111,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -0,0 +1,60 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")
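
A usage sketch for the new module, matching how the indexing proxy task calls it further down (the process name and metadata keys are whatever the caller wants stamped onto the log line):

import os

from onyx.background.celery.memory_monitoring import emit_process_memory

# writes one PROCESS_MEMORY line (rss/vms in MB plus cpu%) to the rotating log file
emit_process_memory(
    os.getpid(),
    "example_worker",
    {"cc_pair_id": 42, "index_attempt_id": 7},
)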

View File

@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
@@ -984,6 +985,9 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1024,6 +1028,23 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1170,6 +1191,7 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1217,6 +1239,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,
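
The once-per-minute emit above, together with the "Switch to monotonic" commit at the top of this compare, uses time.monotonic() rather than time.time() so the throttle cannot misfire if the wall clock jumps. A minimal sketch of the throttle in isolation (the emit itself is stubbed out):

import time

EMIT_INTERVAL_SECONDS = 60.0
last_emit_time = 0.0

def maybe_emit_memory() -> None:
    global last_emit_time
    now = time.monotonic()  # immune to NTP corrections and wall-clock changes
    if now - last_emit_time >= EMIT_INTERVAL_SECONDS:
        print("emitting memory stats...")  # stand-in for emit_process_memory(...)
        last_emit_time = now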

View File

@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
last_citation_end = end + length_to_add
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,4 +333,45 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
AGENT_MAX_TOKENS_VALIDATION = int(
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
)
GRAPH_VERSION_NAME: str = "a"
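
Each of the new AGENT_MAX_TOKENS_* settings uses the same override pattern as the existing timeout settings: int(os.environ.get(NAME) or DEFAULT), so an unset or empty environment variable falls back to the default. A standalone illustration of the pattern (the variable name is made up for the example):

import os

DEFAULT_MAX_TOKENS = 1024

# unset or empty string -> default; any non-empty value -> parsed as an int
max_tokens = int(os.environ.get("EXAMPLE_MAX_TOKENS") or DEFAULT_MAX_TOKENS)
assert max_tokens == 1024

os.environ["EXAMPLE_MAX_TOKENS"] = "256"
assert int(os.environ.get("EXAMPLE_MAX_TOKENS") or DEFAULT_MAX_TOKENS) == 256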

View File

@@ -66,9 +66,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
@@ -305,7 +302,9 @@ class ConfluenceConnector(
# Create the document
return Document(
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -376,7 +375,7 @@ class ConfluenceConnector(
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:

View File

@@ -316,7 +316,9 @@ class GoogleDriveConnector(
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
retry_builder()(get_root_folder_id)(drive_service)
# default is ~17 mins of retries; don't do that here so we don't
# waste 17 mins every time we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue

View File

@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
model_config = ConfigDict(arbitrary_types_allowed=True)
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
offset: int = 0
model_config = ConfigDict(frozen=True)
precomputed_query_embedding: Embedding | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@@ -331,6 +331,14 @@ class SearchPipeline:
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def retrieved_sections(self) -> list[InferenceSection]:
if self._retrieved_sections is not None:
return self._retrieved_sections
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -343,7 +351,7 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self._get_sections()
retrieved_sections = self.retrieved_sections
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)

View File

@@ -117,8 +117,12 @@ def retrieval_preprocessing(
else None
)
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
# latency, and in that case we don't need to run the model again
run_query_analysis = (
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
None
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
else FunctionCall(query_analysis, (query,), {})
)
functions_to_run = [
@@ -143,11 +147,12 @@ def retrieval_preprocessing(
# The extracted keywords right now are not very reliable, not using for now
# Can maybe use for highlighting
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
)
is_keyword, _extracted_keywords = False, None
if search_request.precomputed_is_keyword is not None:
is_keyword = search_request.precomputed_is_keyword
_extracted_keywords = search_request.precomputed_keywords
elif run_query_analysis:
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
all_query_terms = query.split()
processed_keywords = (
@@ -247,4 +252,5 @@ def retrieval_preprocessing(
chunks_above=chunks_above,
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
)

View File

@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -109,6 +109,20 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -121,17 +135,10 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
query_embedding = query.precomputed_query_embedding or get_query_embedding(
query.query, db_session
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
@@ -249,7 +256,16 @@ def retrieve_chunks(
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
q_copy = query.model_copy(
update={
"query": rephrase,
# need to recompute for each rephrase
# note that `SearchQuery` is a frozen model, so we can't update
# it below
"precomputed_query_embedding": None,
},
deep=True,
)
run_queries.append(
(
doc_index_retrieval,

View File

@@ -1,6 +1,7 @@
import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -9,6 +10,8 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
@@ -19,12 +22,18 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
the times.
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")
rows = db_session.query(ChatSession).all()
for row in rows:
for x in range(0, len(rows)):
if x % 1024 == 0:
logger.info(f"Seeded messages for {x} sessions so far.")
row = rows[x]
row.time_created = datetime.utcnow() - timedelta(
days=random.randint(0, days)
)
@@ -34,20 +43,37 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
root_message = get_or_create_root_message(row.id, db_session)
current_message_type = MessageType.USER
parent_message = root_message
for x in range(0, num_messages):
if current_message_type == MessageType.USER:
msg = f"pytest_message_user_{x}"
else:
msg = f"pytest_message_assistant_{x}"
chat_message = create_new_chat_message(
row.id,
root_message,
f"pytest_message_{x}",
parent_message,
msg,
None,
0,
MessageType.USER,
current_message_type,
db_session,
)
chat_message.time_sent = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
db_session.commit()
db_session.commit()
current_message_type = (
MessageType.ASSISTANT
if current_message_type == MessageType.USER
else MessageType.USER
)
parent_message = chat_message
db_session.commit()
logger.info(f"Seeded messages for {len(rows)} sessions. Finished.")

View File

@@ -167,7 +167,7 @@ def _convert_delta_to_message_chunk(
stop_reason: str | None = None,
) -> BaseMessageChunk:
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
content = _dict.get("content") or ""
additional_kwargs = {}
if _dict.get("function_call"):
@@ -402,6 +402,7 @@ class DefaultMultiLLM(LLM):
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -429,6 +430,7 @@ class DefaultMultiLLM(LLM):
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -484,6 +486,7 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -497,6 +500,7 @@ class DefaultMultiLLM(LLM):
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
choice = response.choices[0]
@@ -515,6 +519,7 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -539,6 +544,7 @@ class DefaultMultiLLM(LLM):
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
try:

View File

@@ -82,6 +82,7 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -92,5 +93,6 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -91,12 +91,18 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
@abc.abstractmethod
@@ -107,6 +113,7 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -117,12 +124,18 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
tokens = []
@@ -142,5 +155,6 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
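
From the caller's side, the new parameter is just an optional completion cap threaded through invoke()/stream(); leaving it as None preserves the previous behavior. A rough sketch against the interface above (the prompt text and values are illustrative, and llm stands for any concrete LLM implementation):

# llm: any onyx.llm.interfaces.LLM implementation
response = llm.invoke(
    prompt="Answer yes or no: is the sky blue?",
    timeout_override=10,
    max_tokens=4,  # cap the completion length; None (the default) means no cap
)
print(response.content)

for chunk in llm.stream(prompt="Summarize the conversation so far.", max_tokens=256):
    print(chunk.content, end="")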

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from shared_configs.model_server_models import Embedding
class ToolResponse(BaseModel):
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
# None indicates that the default value should be used
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
force_no_rerank: bool | None = None
alternate_db_session: Session | None = None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
skip_query_analysis: bool | None = None
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class Config:
arbitrary_types_allowed = True
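
Since every field on SearchToolOverrideKwargs now defaults to None, callers only set what they actually precomputed and the tool falls back to its normal behavior for the rest (see use_alt_not_None in the SearchTool diff below). A minimal construction sketch, assuming embedding, is_keyword, and keywords were computed earlier:

from onyx.tools.models import SearchToolOverrideKwargs

override_kwargs = SearchToolOverrideKwargs(
    precomputed_query_embedding=embedding,  # from get_query_embedding(...)
    precomputed_is_keyword=is_keyword,      # from query_analysis(...)
    precomputed_keywords=keywords,
)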

View File

@@ -3,6 +3,7 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
from sqlalchemy.orm import Session
@@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
@@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs[QUERY_FIELD])
precomputed_query_embedding = None
precomputed_is_keyword = None
precomputed_keywords = None
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis
skip_query_analysis = use_alt_not_None(
override_kwargs.skip_query_analysis, False
)
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
return
@@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if self.retrieval_options
else None
),
precomputed_query_embedding=precomputed_query_embedding,
precomputed_is_keyword=precomputed_is_keyword,
precomputed_keywords=precomputed_keywords,
),
user=self.user,
llm=self.llm,
@@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
yield from yield_search_responses(
query,
search_pipeline.reranked_sections,
search_pipeline.final_context_sections,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
@@ -383,10 +397,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
#
# The various inference sections are passed in as functions to allow for lazy
# evaluation. The SearchPipeline object properties that they correspond to are
# actually functions defined with @property decorators, and passing them into
# this function causes them to be evaluated immediately, which is undesirable.
def yield_search_responses(
query: str,
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
@@ -395,7 +415,7 @@ def yield_search_responses(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=final_context_sections,
top_sections=get_retrieved_sections(),
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
@@ -407,13 +427,8 @@ def yield_search_responses(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
@@ -424,6 +439,7 @@ def yield_search_responses(
response=section_relevance,
)
final_context_sections = get_final_context_sections()
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
@@ -438,3 +454,10 @@ def yield_search_responses(
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
T = TypeVar("T")
def use_alt_not_None(value: T | None, alt: T) -> T:
return value if value is not None else alt
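The lazy-evaluation comment above is the key change in this file: yield_search_responses now receives zero-argument callables instead of already-evaluated lists, so the @property-backed SearchPipeline attributes only run when a consumer actually asks for them. A small standalone sketch of the idea (PipelineSketch and consume are illustrative names, not onyx code):

# Standalone illustration of the lazy-evaluation pattern described in the comment above.
# Passing pipeline.expensive_sections directly would trigger the @property immediately;
# wrapping it in a lambda defers the work until the consumer actually calls the getter.
from collections.abc import Callable

class PipelineSketch:
    @property
    def expensive_sections(self) -> list[str]:
        print("running retrieval...")  # stands in for the real, costly work
        return ["section-1", "section-2"]

def consume(get_sections: Callable[[], list[str]], needed: bool) -> list[str]:
    if not needed:
        return []            # the @property above never runs in this branch
    return get_sections()    # evaluated only here

pipeline = PipelineSketch()
consume(lambda: pipeline.expensive_sections, needed=False)  # prints nothing
consume(lambda: pipeline.expensive_sections, needed=True)   # prints "running retrieval..."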

View File

@@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
"%B %d, %Y %H:%M"
)
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)

View File

@@ -1,6 +1,8 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
R = TypeVar("R")
class ToolRunner(Generic[R]):
def __init__(
self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
):
self.tool = tool
self.args = args
self.override_kwargs = override_kwargs
self._tool_responses: list[ToolResponse] | None = None
@@ -27,7 +35,9 @@ class ToolRunner:
return
tool_responses: list[ToolResponse] = []
for tool_response in self.tool.run(**self.args):
for tool_response in self.tool.run(
override_kwargs=self.override_kwargs, **self.args
):
yield tool_response
tool_responses.append(tool_response)
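ToolRunner is now parameterized by the override-kwargs type R that its tool accepts, so the override object it forwards to tool.run is type-checked against the specific tool. A minimal standalone sketch of that generic-runner shape (ToolSketch and RunnerSketch are illustrative, not the real onyx classes):

# Minimal sketch (not the real onyx classes) of making a runner generic over the
# override-kwargs type its tool accepts, so mismatched overrides fail type checking.
from typing import Any, Generic, TypeVar

R = TypeVar("R")

class ToolSketch(Generic[R]):
    def run(self, override_kwargs: R | None = None, **kwargs: Any) -> str:
        return f"ran with overrides={override_kwargs!r} and args={kwargs!r}"

class RunnerSketch(Generic[R]):
    def __init__(
        self, tool: ToolSketch[R], args: dict[str, Any], override_kwargs: R | None = None
    ):
        self.tool = tool
        self.args = args
        self.override_kwargs = override_kwargs

    def run(self) -> str:
        # Overrides travel separately from the LLM-chosen arguments.
        return self.tool.run(override_kwargs=self.override_kwargs, **self.args)

tool: ToolSketch[dict[str, bool]] = ToolSketch()
runner = RunnerSketch(tool, {"query": "hello"}, override_kwargs={"force_no_rerank": True})
print(runner.run())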

View File

@@ -118,7 +118,7 @@ def run_functions_in_parallel(
return results
class TimeoutThread(threading.Thread):
class TimeoutThread(threading.Thread, Generic[R]):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
@@ -159,3 +159,34 @@ def run_with_timeout(
task.end()
return task.result
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
# difficult to use. It's up to the programmer to call wait_on_background on the thread after
# the code you want to run in parallel is finished. As with all python thread parallelism,
# this is only useful for I/O bound tasks.
def run_in_background(
func: Callable[..., R], *args: Any, **kwargs: Any
) -> TimeoutThread[R]:
"""
Runs a function in a background thread. Returns a TimeoutThread object that can be used
to wait for the function to finish with wait_on_background.
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task.start()
return task
def wait_on_background(task: TimeoutThread[R]) -> R:
"""
Used in conjunction with run_in_background. Blocks until the task is finished,
then returns the result of the task.
"""
task.join()
if task.exception is not None:
raise task.exception
return task.result
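run_in_background starts the work on a TimeoutThread with a copy of the caller's contextvars, and wait_on_background joins the thread and re-raises any stored exception. A standalone sketch of the same mechanism, assuming a simple thread subclass that records its result or exception (the real code reuses the existing TimeoutThread class for this):

# Standalone sketch of the run_in_background / wait_on_background pattern.
import contextvars
import threading
from collections.abc import Callable
from typing import Any, Generic, TypeVar

R = TypeVar("R")

class BackgroundTask(threading.Thread, Generic[R]):
    def __init__(self, func: Callable[..., R], *args: Any, **kwargs: Any):
        super().__init__()
        self._func, self._args, self._kwargs = func, args, kwargs
        self.result: R | None = None
        self.exception: BaseException | None = None

    def run(self) -> None:
        try:
            self.result = self._func(*self._args, **self._kwargs)
        except BaseException as e:
            self.exception = e

def run_in_background_sketch(func: Callable[..., R], *args: Any, **kwargs: Any) -> BackgroundTask[R]:
    ctx = contextvars.copy_context()  # background work sees a copy of the caller's contextvars
    task = BackgroundTask(ctx.run, func, *args, **kwargs)
    task.start()
    return task

def wait_on_background_sketch(task: BackgroundTask[R]) -> R:
    task.join()
    if task.exception is not None:
        raise task.exception  # surface background failures to the caller
    return task.result  # type: ignore[return-value]

task = run_in_background_sketch(lambda x: x * 2, 21)
print(wait_on_background_sketch(task))  # 42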

View File

@@ -108,6 +108,7 @@ command=tail -qF
/var/log/celery_worker_light.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_indexing.log
/var/log/celery_worker_monitoring.log
/var/log/slack_bot.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout

View File

@@ -41,5 +41,10 @@ def test_confluence_connector_permissions(
for slim_doc_batch in confluence_connector.retrieve_all_slim_documents():
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
# Find IDs that are in full but not in slim
difference = all_full_doc_ids - all_slim_doc_ids
# The set of full doc IDs should always be a subset of the slim doc IDs
assert all_full_doc_ids.issubset(all_slim_doc_ids)
assert all_full_doc_ids.issubset(
all_slim_doc_ids
), f"Full doc IDs are not a subset of slim doc IDs. Found {len(difference)} IDs in full docs but not in slim docs."

View File

@@ -25,7 +25,7 @@ from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout
from tests.integration.common_utils.timeout import run_with_timeout_multiproc
logger = setup_logger()
@@ -161,7 +161,7 @@ def reset_postgres(
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout(
run_with_timeout_multiproc(
downgrade_postgres,
TIMEOUT,
kwargs={

View File

@@ -6,7 +6,9 @@ from typing import TypeVar
T = TypeVar("T")
def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
def run_with_timeout_multiproc(
task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread
with multiprocessing.Pool(processes=1) as pool:
async_result = pool.apply_async(task, kwds=kwargs)
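The renamed run_with_timeout_multiproc hands the task to a single-process pool so a hung task can be abandoned rather than blocking the main thread. Only the first lines of the body appear in this diff; a minimal sketch of how such a pool-based timeout is typically completed, assuming the remainder uses AsyncResult.get:

# Minimal sketch of a pool-based timeout, assuming the body completes with
# AsyncResult.get (only the first lines of the real function appear in this diff).
import multiprocessing
from collections.abc import Callable
from typing import Any, TypeVar

T = TypeVar("T")

def run_with_timeout_multiproc_sketch(
    task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T:
    # A separate process can be abandoned on timeout, unlike a blocked thread.
    with multiprocessing.Pool(processes=1) as pool:
        async_result = pool.apply_async(task, kwds=kwargs)
        try:
            return async_result.get(timeout=timeout)
        except multiprocessing.TimeoutError:
            raise TimeoutError(f"Task did not finish within {timeout} seconds")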

View File

@@ -10,7 +10,9 @@ from onyx.db.seeding.chat_history_seeding import seed_chat_history
def test_usage_reports(reset: None) -> None:
EXPECTED_SESSIONS = 2048
MESSAGES_PER_SESSION = 4
EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION
# divide by 2 because only messages of type USER are returned
EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION / 2
seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90)
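With the values seeded here this works out to 2048 sessions × 4 messages = 8192 messages in total, and the expected count is halved to 4096 on the assumption that half of the seeded messages are of type USER, which is all the query-history report returns.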

View File

@@ -145,6 +145,7 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)
@@ -290,4 +291,5 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)

View File

@@ -1,8 +1,14 @@
import contextvars
import time
import pytest
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a context variable for testing
test_context_var = contextvars.ContextVar("test_var", default="default")
def test_run_with_timeout_completes() -> None:
@@ -59,3 +65,86 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
# Test with positional and keyword args
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
assert result2 == 15
def test_run_in_background_and_wait_success() -> None:
"""Test that run_in_background and wait_on_background work correctly for successful execution"""
def background_function(x: int) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
return x * 2
# Start the background task
task = run_in_background(background_function, 21)
# Verify we can do other work while task is running
start_time = time.time()
result = wait_on_background(task)
elapsed = time.time() - start_time
assert result == 42
assert elapsed >= 0.1 # Verify we actually waited for the sleep
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_run_in_background_propagates_exceptions() -> None:
"""Test that exceptions in background tasks are properly propagated"""
def error_function() -> None:
time.sleep(0.1) # Small delay to ensure it's actually running in background
raise ValueError("Test background error")
task = run_in_background(error_function)
with pytest.raises(ValueError) as exc_info:
wait_on_background(task)
assert "Test background error" in str(exc_info.value)
def test_run_in_background_with_args_and_kwargs() -> None:
"""Test that args and kwargs are properly passed to the background function"""
def complex_function(x: int, y: int, multiply: bool = False) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
if multiply:
return x * y
return x + y
# Test with args
task1 = run_in_background(complex_function, 5, 3)
result1 = wait_on_background(task1)
assert result1 == 8
# Test with args and kwargs
task2 = run_in_background(complex_function, 5, 3, multiply=True)
result2 = wait_on_background(task2)
assert result2 == 15
def test_multiple_background_tasks() -> None:
"""Test running multiple background tasks concurrently"""
def slow_add(x: int, y: int) -> int:
time.sleep(0.2) # Make each task take some time
return x + y
# Start multiple tasks
start_time = time.time()
task1 = run_in_background(slow_add, 1, 2)
task2 = run_in_background(slow_add, 3, 4)
task3 = run_in_background(slow_add, 5, 6)
# Wait for all results
result1 = wait_on_background(task1)
result2 = wait_on_background(task2)
result3 = wait_on_background(task3)
elapsed = time.time() - start_time
# Verify results
assert result1 == 3
assert result2 == 7
assert result3 == 11
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations

View File

@@ -4,7 +4,9 @@ import time
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a test contextvar
test_var = contextvars.ContextVar("test_var", default="default")
@@ -129,3 +131,39 @@ def test_contextvar_isolation_between_runs() -> None:
# Verify second run results
assert all(result in ["thread3", "thread4"] for result in second_results)
def test_run_in_background_preserves_contextvar() -> None:
"""Test that run_in_background preserves contextvar values and modifications are isolated"""
def modify_and_sleep() -> tuple[str, str]:
"""Modifies contextvar, sleeps, and returns original, modified, and final values"""
original = test_var.get()
test_var.set("modified_in_background")
time.sleep(0.1) # Ensure we can check main thread during execution
final = test_var.get()
return original, final
# Set initial value in main thread
token = test_var.set("initial_value")
try:
# Start background task
task = run_in_background(modify_and_sleep)
# Verify main thread value remains unchanged while task runs
assert test_var.get() == "initial_value"
# Get results from background thread
original, modified = wait_on_background(task)
# Verify the background thread:
# 1. Saw the initial value
assert original == "initial_value"
# 2. Successfully modified its own copy
assert modified == "modified_in_background"
# Verify main thread value is still unchanged after task completion
assert test_var.get() == "initial_value"
finally:
# Clean up
test_var.reset(token)

View File

@@ -254,6 +254,9 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -431,3 +434,4 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -209,6 +209,9 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -384,3 +387,4 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -244,6 +244,8 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -421,3 +423,4 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -54,6 +54,9 @@ services:
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -233,3 +236,4 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -68,6 +68,8 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -229,3 +231,4 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -32,6 +32,8 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -73,6 +75,8 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -260,3 +264,4 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -62,6 +62,8 @@ services:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -221,3 +223,4 @@ volumes:
type: none
o: bind
device: ${DANSWER_VESPA_DATA_DIR:-./vespa_data}
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -1,17 +1,14 @@
import {
AnthropicIcon,
AmazonIcon,
AWSIcon,
AzureIcon,
CPUIcon,
MicrosoftIconSVG,
MistralIcon,
MetaIcon,
GeminiIcon,
AnthropicSVG,
IconProps,
OpenAIISVG,
DeepseekIcon,
OpenAISVG,
} from "@/components/icons/icons";
export interface CustomConfigKey {
@@ -74,7 +71,7 @@ export interface LLMProviderDescriptor {
}
export const getProviderIcon = (providerName: string, modelName?: string) => {
const modelIconMap: Record<
const iconMap: Record<
string,
({ size, className }: IconProps) => JSX.Element
> = {
@@ -86,34 +83,30 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
gemini: GeminiIcon,
deepseek: DeepseekIcon,
claude: AnthropicIcon,
anthropic: AnthropicIcon,
openai: OpenAISVG,
microsoft: MicrosoftIconSVG,
meta: MetaIcon,
google: GeminiIcon,
};
const modelNameToIcon = (
modelName: string,
fallbackIcon: ({ size, className }: IconProps) => JSX.Element
): (({ size, className }: IconProps) => JSX.Element) => {
const lowerModelName = modelName?.toLowerCase();
for (const [key, icon] of Object.entries(modelIconMap)) {
if (lowerModelName?.includes(key)) {
// First check if provider name directly matches an icon
if (providerName.toLowerCase() in iconMap) {
return iconMap[providerName.toLowerCase()];
}
// Then check if model name contains any of the keys
if (modelName) {
const lowerModelName = modelName.toLowerCase();
for (const [key, icon] of Object.entries(iconMap)) {
if (lowerModelName.includes(key)) {
return icon;
}
}
return fallbackIcon;
};
switch (providerName) {
case "openai":
// Special cases for openai based on modelName
return modelNameToIcon(modelName || "", OpenAIISVG);
case "anthropic":
return AnthropicSVG;
case "bedrock":
return AWSIcon;
case "azure":
return AzureIcon;
default:
return modelNameToIcon(modelName || "", CPUIcon);
}
// Fallback to CPU icon if no matches
return CPUIcon;
};
export const isAnthropic = (provider: string, modelName: string) =>

View File

@@ -185,7 +185,10 @@ export const FilterComponent = forwardRef<
hasActiveFilters ? "border-primary bg-primary/5" : ""
}`}
>
<SortIcon size={20} className="text-neutral-800" />
<SortIcon
size={20}
className="text-neutral-800 dark:text-neutral-200"
/>
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent
@@ -365,7 +368,7 @@ export const FilterComponent = forwardRef<
{hasActiveFilters && (
<div className="absolute -top-1 -right-1">
<Badge className="h-2 bg-red-400 border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
<Badge className="h-2 !bg-red-400 !border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
</div>
)}
</div>

View File

@@ -3102,27 +3102,6 @@ export const OpenAISVG = ({
);
};
export const AnthropicSVG = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 92.2 65"
xmlSpace="preserve"
fill="currentColor"
>
<path
fill="currentColor"
d="M66.5,0H52.4l25.7,65h14.1L66.5,0z M25.7,0L0,65h14.4l5.3-13.6h26.9L51.8,65h14.4L40.5,0C40.5,0,25.7,0,25.7,0z M24.3,39.3l8.8-22.8l8.8,22.8H24.3z"
/>
</svg>
);
};
export const SourcesIcon = ({
size = 16,
className = defaultTailwindCSS,