Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-02-20 01:05:46 +00:00

Compare commits: bb...agent-sear (10 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 1c23cf574c |  |
|  | 0ccf78ac52 |  |
|  | 02b4b4bf0d |  |
|  | dade11a2e6 |  |
|  | 188a5f0d62 |  |
|  | 89c0b1ad37 |  |
|  | 8b20fd31b6 |  |
|  | 6a73245986 |  |
|  | dd73fdcd08 |  |
|  | 768456609a |  |
@@ -133,4 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
 ## ⭐Star History
 
 [](https://star-history.com/#onyx-dot-app/onyx&Date)
-
@@ -365,9 +365,7 @@ def confluence_doc_sync(
     slim_docs = []
     logger.debug("Fetching all slim documents from confluence")
-    for doc_batch in confluence_connector.retrieve_all_slim_documents(
-        callback=callback
-    ):
+    for doc_batch in confluence_connector.retrieve_all_slim_documents():
         logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
         if callback:
             if callback.should_stop():

@@ -15,7 +15,6 @@ logger = setup_logger()
 def _get_slim_doc_generator(
     cc_pair: ConnectorCredentialPair,
     gmail_connector: GmailConnector,
-    callback: IndexingHeartbeatInterface | None = None,
 ) -> GenerateSlimDocumentOutput:
     current_time = datetime.now(timezone.utc)
     start_time = (

@@ -25,9 +24,7 @@ def _get_slim_doc_generator(
     )
 
     return gmail_connector.retrieve_all_slim_documents(
-        start=start_time,
-        end=current_time.timestamp(),
-        callback=callback,
+        start=start_time, end=current_time.timestamp()
     )

@@ -43,9 +40,7 @@ def gmail_doc_sync(
     gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
     gmail_connector.load_credentials(cc_pair.credential.credential_json)
 
-    slim_doc_generator = _get_slim_doc_generator(
-        cc_pair, gmail_connector, callback=callback
-    )
+    slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
 
     document_external_access: list[DocExternalAccess] = []
     for slim_doc_batch in slim_doc_generator:

@@ -21,7 +21,6 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
 def _get_slim_doc_generator(
     cc_pair: ConnectorCredentialPair,
     google_drive_connector: GoogleDriveConnector,
-    callback: IndexingHeartbeatInterface | None = None,
 ) -> GenerateSlimDocumentOutput:
     current_time = datetime.now(timezone.utc)
     start_time = (

@@ -31,9 +30,7 @@ def _get_slim_doc_generator(
     )
 
     return google_drive_connector.retrieve_all_slim_documents(
-        start=start_time,
-        end=current_time.timestamp(),
-        callback=callback,
+        start=start_time, end=current_time.timestamp()
     )

@@ -20,11 +20,19 @@ def _get_slack_document_ids_and_channels(
     slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
     slack_connector.load_credentials(cc_pair.credential.credential_json)
 
-    slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
+    slim_doc_generator = slack_connector.retrieve_all_slim_documents()
 
     channel_doc_map: dict[str, list[str]] = {}
     for doc_metadata_batch in slim_doc_generator:
         for doc_metadata in doc_metadata_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError(
+                        "_get_slack_document_ids_and_channels: Stop signal detected"
+                    )
+
+                callback.progress("_get_slack_document_ids_and_channels", 1)
+
             if doc_metadata.perm_sync_data is None:
                 continue
             channel_id = doc_metadata.perm_sync_data["channel_id"]

@@ -32,14 +40,6 @@ def _get_slack_document_ids_and_channels(
                 channel_doc_map[channel_id] = []
             channel_doc_map[channel_id].append(doc_metadata.id)
 
-            if callback:
-                if callback.should_stop():
-                    raise RuntimeError(
-                        "_get_slack_document_ids_and_channels: Stop signal detected"
-                    )
-
-                callback.progress("_get_slack_document_ids_and_channels", 1)
-
     return channel_doc_map

@@ -9,7 +9,6 @@ class CoreState(BaseModel):
     This is the core state that is shared across all subgraphs.
     """
 
-    base_question: str = ""
     log_messages: Annotated[list[str], add] = []

@@ -18,4 +17,4 @@ class SubgraphCoreState(BaseModel):
     This is the core state that is shared across all subgraphs.
     """
 
-    log_messages: Annotated[list[str], add]
+    log_messages: Annotated[list[str], add] = []

@@ -1,8 +1,8 @@
 from datetime import datetime
 from typing import cast
 
+from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
-from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
 
 from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (

@@ -12,12 +12,39 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
     SubQuestionAnswerCheckUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
+    binary_string_test,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_POSITIVE_VALUE_STR,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
+    rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
+    general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
+)
 
 
 def check_sub_answer(

@@ -53,14 +80,46 @@ def check_sub_answer(
 
     graph_config = cast(GraphConfig, config["metadata"]["config"])
     fast_llm = graph_config.tooling.fast_llm
-    response = list(
-        fast_llm.stream(
-            prompt=msg,
-        )
-    )
-
-    quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
-    answer_quality = "yes" in quality_str.lower()
+    agent_error: AgentErrorLoggingFormat | None = None
+    response: BaseMessage | None = None
+    try:
+        response = fast_llm.invoke(
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
+        )
+
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - check sub answer")
+
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - check sub answer")
+
+    if agent_error:
+        answer_quality = True
+        log_result = agent_error.error_result
+
+    else:
+        if response:
+            quality_str: str = cast(str, response.content)
+            answer_quality = binary_string_test(
+                text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
+            )
+
+        else:
+            answer_quality = True
+            quality_str = "yes - because LLM error"
+
+        log_result = f"Answer quality: {quality_str}"
 
     return SubQuestionAnswerCheckUpdate(
         answer_quality=answer_quality,

@@ -69,7 +128,7 @@ def check_sub_answer(
                 graph_component="initial - generate individual sub answer",
                 node_name="check sub answer",
                 node_start_time=node_start_time,
-                result=f"Answer quality: {quality_str}",
+                result=log_result,
             )
         ],
     )

@@ -16,6 +16,20 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_sub_question_answer_prompt,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    LLM_ANSWER_ERROR_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,

@@ -30,11 +44,20 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import NO_RECOVERED_DOCS
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
+    rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
+    general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
+)
+
 
 def generate_sub_answer(
     state: AnswerQuestionState,

@@ -57,6 +80,8 @@ def generate_sub_answer(
 
     if len(context_docs) == 0:
         answer_str = NO_RECOVERED_DOCS
+        cited_documents: list = []
+        log_results = "No documents retrieved"
         write_custom_event(
             "sub_answers",
             AgentAnswerPiece(

@@ -79,41 +104,67 @@ def generate_sub_answer(
 
     response: list[str | list[str | dict[str, Any]]] = []
     dispatch_timings: list[float] = []
-    for message in fast_llm.stream(
-        prompt=msg,
-    ):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-        start_stream_token = datetime.now()
-        write_custom_event(
-            "sub_answers",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=level,
-                level_question_num=question_num,
-                answer_type="agent_sub_answer",
-            ),
-            writer,
-        )
-        end_stream_token = datetime.now()
-        dispatch_timings.append(
-            (end_stream_token - start_stream_token).microseconds
-        )
-        response.append(content)
 
-    answer_str = merge_message_runs(response, chunk_separator="")[0].content
+    agent_error: AgentErrorLoggingFormat | None = None
+
+    try:
+        for message in fast_llm.stream(
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
+        ):
+            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+            content = message.content
+            if not isinstance(content, str):
+                raise ValueError(
+                    f"Expected content to be a string, but got {type(content)}"
+                )
+            start_stream_token = datetime.now()
+            write_custom_event(
+                "sub_answers",
+                AgentAnswerPiece(
+                    answer_piece=content,
+                    level=level,
+                    level_question_num=question_num,
+                    answer_type="agent_sub_answer",
+                ),
+                writer,
+            )
+            end_stream_token = datetime.now()
+            dispatch_timings.append(
+                (end_stream_token - start_stream_token).microseconds
+            )
+            response.append(content)
+
+        logger.debug(
+            f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
+        )
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - generate sub answer")
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - generate sub answer")
 
-    answer_citation_ids = get_answer_citation_ids(answer_str)
-    cited_documents = [
-        context_docs[id] for id in answer_citation_ids if id < len(context_docs)
-    ]
+    if agent_error:
+        answer_str = LLM_ANSWER_ERROR_MESSAGE
+        cited_documents = []
+        log_results = (
+            agent_error.error_result
+            or "Sub-answer generation failed due to LLM error"
+        )
+
+    else:
+        answer_str = merge_message_runs(response, chunk_separator="")[0].content
+        answer_citation_ids = get_answer_citation_ids(answer_str)
+        cited_documents = [
+            context_docs[id] for id in answer_citation_ids if id < len(context_docs)
+        ]
+        log_results = None
 
     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,

@@ -131,7 +182,7 @@ def generate_sub_answer(
                 graph_component="initial - generate individual sub answer",
                 node_name="generate sub answer",
                 node_start_time=node_start_time,
-                result="",
+                result=log_results or "",
             )
         ],
     )

@@ -42,10 +42,8 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
 
 
 class SubQuestionAnsweringInput(SubgraphCoreState):
-    question: str = ""
-    question_id: str = (
-        ""  # 0_0 is original question, everything else is <level>_<question_num>.
-    )
+    question: str
+    question_id: str
     # level 0 is original question and first decomposition, level 1 is follow up, etc
     # question_num is a unique number per original question per level.

@@ -26,7 +26,18 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )

@@ -42,12 +53,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
+from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
-from onyx.context.search.models import InferenceSection
-from onyx.prompts.agent_search import (
-    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
 )
+from onyx.context.search.models import InferenceSection
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
+from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
 from onyx.prompts.agent_search import (
     INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )

@@ -57,6 +72,12 @@ from onyx.prompts.agent_search import (
 )
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
 
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="LLM Timeout Error. The initial answer could not be generated.",
+    rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
+    general_error="General LLM Error. The initial answer could not be generated.",
+)
+
 
 def generate_initial_answer(
     state: SubQuestionRetrievalState,

@@ -224,30 +245,82 @@ def generate_initial_answer(
 
     streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
-    for message in model.stream(msg):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-        start_stream_token = datetime.now()
-
-        write_custom_event(
-            "initial_agent_answer",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=0,
-                level_question_num=0,
-                answer_type="agent_level_answer",
-            ),
-            writer,
-        )
-        end_stream_token = datetime.now()
-        dispatch_timings.append(
-            (end_stream_token - start_stream_token).microseconds
-        )
-        streamed_tokens.append(content)
+
+    agent_error: AgentErrorLoggingFormat | None = None
+
+    try:
+        for message in model.stream(
+            msg,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
+        ):
+            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+            content = message.content
+            if not isinstance(content, str):
+                raise ValueError(
+                    f"Expected content to be a string, but got {type(content)}"
+                )
+            start_stream_token = datetime.now()
+
+            write_custom_event(
+                "initial_agent_answer",
+                AgentAnswerPiece(
+                    answer_piece=content,
+                    level=0,
+                    level_question_num=0,
+                    answer_type="agent_level_answer",
+                ),
+                writer,
+            )
+            end_stream_token = datetime.now()
+            dispatch_timings.append(
+                (end_stream_token - start_stream_token).microseconds
+            )
+            streamed_tokens.append(content)
+
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - generate initial answer")
+
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - generate initial answer")
+
+    if agent_error:
+        write_custom_event(
+            "initial_agent_answer",
+            StreamingError(
+                error=AGENT_LLM_TIMEOUT_MESSAGE,
+            ),
+            writer,
+        )
+        return InitialAnswerUpdate(
+            initial_answer=None,
+            error=AgentErrorLoggingFormat(
+                error_message=agent_error.error_message or "An LLM error occurred",
+                error_type=agent_error.error_type,
+                error_result=agent_error.error_result,
+            ),
+            initial_agent_stats=None,
+            generated_sub_questions=sub_questions,
+            agent_base_end_time=None,
+            agent_base_metrics=None,
+            log_messages=[
+                get_langgraph_node_log_string(
+                    graph_component="initial - generate initial answer",
+                    node_name="generate initial answer",
+                    node_start_time=node_start_time,
+                    result=agent_error.error_result or "An LLM error occurred",
+                )
+            ],
+        )
 
     logger.debug(
         f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"

@@ -25,7 +25,7 @@ def validate_initial_answer(
         f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
     )
 
-    verdict = True
+    verdict = True  # not actually required as already streamed out. Refinement will do similar
 
     return InitialAnswerQualityUpdate(
         initial_answer_quality_eval=verdict,

@@ -23,6 +23,18 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,

@@ -33,6 +45,11 @@ from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
 from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
+)
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )

@@ -43,6 +60,12 @@ from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="LLM Timeout Error. Sub-questions could not be generated.",
+    rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
+    general_error="General LLM Error. Sub-questions could not be generated.",
+)
+
 
 def decompose_orig_question(
     state: SubQuestionRetrievalState,

@@ -112,11 +135,35 @@ def decompose_orig_question(
     )
 
     # dispatches custom events for subquestion tokens, adding in subquestion ids.
-    streamed_tokens = dispatch_separated(
-        model.stream(msg),
-        dispatch_subquestion(0, writer),
-        sep_callback=dispatch_subquestion_sep(0, writer),
-    )
+    agent_error: AgentErrorLoggingFormat | None = None
+    streamed_tokens: list[BaseMessage_Content] = []
+
+    try:
+        streamed_tokens = dispatch_separated(
+            model.stream(
+                msg,
+                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
+            ),
+            dispatch_subquestion(0, writer),
+            sep_callback=dispatch_subquestion_sep(0, writer),
+        )
+    except LLMTimeoutError as e:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - decompose orig question")
+        raise e  # fail loudly on this critical step
+    except LLMRateLimitError as e:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - decompose orig question")
+        raise e
 
     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,

@@ -125,19 +172,19 @@ def decompose_orig_question(
     )
     write_custom_event("stream_finished", stop_event, writer)
 
-    deomposition_response = merge_content(*streamed_tokens)
+    if agent_error:
+        initial_sub_questions: list[str] = []
+        log_result = agent_error.error_result
+    else:
+        deomposition_response = merge_content(*streamed_tokens)
 
-    # this call should only return strings. Commenting out for efficiency
-    # assert [type(tok) == str for tok in streamed_tokens]
+        # this call should only return strings. Commenting out for efficiency
+        # assert [type(tok) == str for tok in streamed_tokens]
+        # use no-op cast() instead of str() which runs code
+        # list_of_subquestions = clean_and_parse_list_string(cast(str, response))
+        list_of_subqs = cast(str, deomposition_response).split("\n")
 
-    # use no-op cast() instead of str() which runs code
-    # list_of_subquestions = clean_and_parse_list_string(cast(str, response))
-    list_of_subqs = cast(str, deomposition_response).split("\n")
-
-    decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
+        initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
+        log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
 
     return InitialQuestionDecompositionUpdate(
-        initial_sub_questions=decomp_list,
+        initial_sub_questions=initial_sub_questions,
         agent_start_time=agent_start_time,
         agent_refined_start_time=None,
         agent_refined_end_time=None,

@@ -151,7 +198,7 @@ def decompose_orig_question(
                 graph_component="initial - generate sub answers",
                 node_name="decompose original question",
                 node_start_time=node_start_time,
-                result=f"decomposed original question into {len(decomp_list)} subquestions",
+                result=log_result,
             )
         ],
     )

@@ -252,9 +252,7 @@ if __name__ == "__main__":
         db_session, primary_llm, fast_llm, search_request
     )
 
-    inputs = MainInput(
-        base_question=graph_config.inputs.search_request.query, log_messages=[]
-    )
+    inputs = MainInput(log_messages=[])
 
     for thing in compiled_graph.stream(
         input=inputs,

@@ -1,6 +1,7 @@
 from datetime import datetime
 from typing import cast
 
+from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.types import StreamWriter

@@ -10,14 +11,37 @@ from onyx.agents.agent_search.deep_search.main.states import (
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="The LLM timed out, and the answers could not be compared.",
+    rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
+    general_error="The LLM encountered an error, and the answers could not be compared.",
+)
 
 
 def compare_answers(

@@ -40,15 +64,46 @@ def compare_answers(
 
     msg = [HumanMessage(content=compare_answers_prompt)]
 
+    agent_error: AgentErrorLoggingFormat | None = None
     # Get the rewritten queries in a defined format
     model = graph_config.tooling.fast_llm
 
+    resp: BaseMessage | None = None
+    refined_answer_improvement: bool | None = None
     # no need to stream this
-    resp = model.invoke(msg)
+    try:
+        resp = model.invoke(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
+        )
 
-    refined_answer_improvement = (
-        isinstance(resp.content, str) and "yes" in resp.content.lower()
-    )
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - compare answers")
+        # continue as True in this support step
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - compare answers")
+        # continue as True in this support step
+
+    if agent_error or resp is None:
+        refined_answer_improvement = True
+        if agent_error:
+            log_result = agent_error.error_result
+        else:
+            log_result = "An answer could not be generated."
+
+    else:
+        refined_answer_improvement = (
+            isinstance(resp.content, str) and "yes" in resp.content.lower()
+        )
+        log_result = f"Answer comparison: {refined_answer_improvement}"
 
     write_custom_event(
         "refined_answer_improvement",

@@ -65,7 +120,7 @@
                 graph_component="main",
                 node_name="compare answers",
                 node_start_time=node_start_time,
-                result=f"Answer comparison: {refined_answer_improvement}",
+                result=log_result,
             )
         ],
     )

@@ -21,6 +21,18 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     format_entity_term_extraction,

@@ -30,10 +42,25 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
+from onyx.chat.models import StreamingError
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
+)
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.tools.models import ToolCallKickoff
 from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="The LLM timed out. The sub-questions could not be generated.",
+    rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
+    general_error="The LLM encountered an error. The sub-questions could not be generated.",
+)
+
 
 def create_refined_sub_questions(

@@ -96,29 +123,65 @@ def create_refined_sub_questions(
     # Grader
     model = graph_config.tooling.fast_llm
 
-    streamed_tokens = dispatch_separated(
-        model.stream(msg),
-        dispatch_subquestion(1, writer),
-        sep_callback=dispatch_subquestion_sep(1, writer),
-    )
-    response = merge_content(*streamed_tokens)
+    agent_error: AgentErrorLoggingFormat | None = None
+    streamed_tokens: list[BaseMessage_Content] = []
+    try:
+        streamed_tokens = dispatch_separated(
+            model.stream(
+                msg,
+                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
+            ),
+            dispatch_subquestion(1, writer),
+            sep_callback=dispatch_subquestion_sep(1, writer),
+        )
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - create refined sub questions")
 
-    if isinstance(response, str):
-        parsed_response = [q for q in response.split("\n") if q.strip() != ""]
-    else:
-        raise ValueError("LLM response is not a string")
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - create refined sub questions")
 
-    refined_sub_question_dict = {}
-    for sub_question_num, sub_question in enumerate(parsed_response):
-        refined_sub_question = RefinementSubQuestion(
-            sub_question=sub_question,
-            sub_question_id=make_question_id(1, sub_question_num + 1),
-            verified=False,
-            answered=False,
-            answer="",
+    if agent_error:
+        refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
+        log_result = agent_error.error_result
+        write_custom_event(
+            "refined_sub_question_creation_error",
+            StreamingError(
+                error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
+            ),
+            writer,
         )
 
-        refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
+    else:
+        response = merge_content(*streamed_tokens)
+
+        if isinstance(response, str):
+            parsed_response = [q for q in response.split("\n") if q.strip() != ""]
+        else:
+            raise ValueError("LLM response is not a string")
+
+        refined_sub_question_dict = {}
+        for sub_question_num, sub_question in enumerate(parsed_response):
+            refined_sub_question = RefinementSubQuestion(
+                sub_question=sub_question,
+                sub_question_id=make_question_id(1, sub_question_num + 1),
+                verified=False,
+                answered=False,
+                answer="",
+            )
+
+            refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
+
+        log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
 
     return RefinedQuestionDecompositionUpdate(
         refined_sub_questions=refined_sub_question_dict,

@@ -128,7 +191,7 @@ def create_refined_sub_questions(
                 graph_component="main",
                 node_name="create refined sub questions",
                 node_start_time=node_start_time,
-                result=f"Created {len(refined_sub_question_dict)} refined sub questions",
+                result=log_result,
             )
         ],
     )

@@ -26,6 +26,19 @@ def decide_refinement_need(
 
     decision = True  # TODO: just for current testing purposes
 
+    if state.error:
+        return RequireRefinemenEvalUpdate(
+            require_refined_answer_eval=False,
+            log_messages=[
+                get_langgraph_node_log_string(
+                    graph_component="main",
+                    node_name="decide refinement need",
+                    node_start_time=node_start_time,
+                    result="Timeout Error",
+                )
+            ],
+        )
+
     log_messages = [
         get_langgraph_node_log_string(
             graph_component="main",

@@ -21,6 +21,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
+)
 from onyx.configs.constants import NUM_EXPLORATORY_DOCS
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE

@@ -81,6 +84,7 @@ def extract_entities_terms(
     # Grader
     llm_response = fast_llm.invoke(
         prompt=msg,
+        timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
     )
 
     cleaned_response = (

@@ -11,7 +11,6 @@ from onyx.agents.agent_search.deep_search.main.models import (
     AgentRefinedMetrics,
 )
 from onyx.agents.agent_search.deep_search.main.operations import get_query_info
-from onyx.agents.agent_search.deep_search.main.operations import logger
 from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.main.states import (
     RefinedAnswerUpdate,

@@ -23,7 +22,18 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
 from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,

@@ -43,8 +53,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
+from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
+)
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
 )

@@ -56,6 +72,15 @@ from onyx.prompts.agent_search import (
 )
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="The LLM timed out. The refined answer could not be generated.",
+    rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
+    general_error="The LLM encountered an error. The refined answer could not be generated.",
+)
 
 
 def generate_refined_answer(

@@ -231,28 +256,80 @@ def generate_refined_answer(
 
     streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
-    for message in model.stream(msg):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-
-        start_stream_token = datetime.now()
-        write_custom_event(
-            "refined_agent_answer",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=1,
-                level_question_num=0,
-                answer_type="agent_level_answer",
-            ),
-            writer,
-        )
-        end_stream_token = datetime.now()
-        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
-        streamed_tokens.append(content)
+    agent_error: AgentErrorLoggingFormat | None = None
+
+    try:
+        for message in model.stream(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
+        ):
+            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+            content = message.content
+            if not isinstance(content, str):
+                raise ValueError(
+                    f"Expected content to be a string, but got {type(content)}"
+                )
+
+            start_stream_token = datetime.now()
+            write_custom_event(
+                "refined_agent_answer",
+                AgentAnswerPiece(
+                    answer_piece=content,
+                    level=1,
+                    level_question_num=0,
+                    answer_type="agent_level_answer",
+                ),
+                writer,
+            )
+            end_stream_token = datetime.now()
+            dispatch_timings.append(
+                (end_stream_token - start_stream_token).microseconds
+            )
+            streamed_tokens.append(content)
+
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - generate refined answer")
+
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - generate refined answer")
+
+    if agent_error:
+        write_custom_event(
+            "initial_agent_answer",
+            StreamingError(
+                error=AGENT_LLM_TIMEOUT_MESSAGE,
+            ),
+            writer,
+        )
+        return RefinedAnswerUpdate(
+            refined_answer=None,
+            refined_answer_quality=False,  # TODO: replace this with the actual check value
+            refined_agent_stats=None,
+            agent_refined_end_time=None,
+            agent_refined_metrics=AgentRefinedMetrics(
+                refined_doc_boost_factor=0.0,
+                refined_question_boost_factor=0.0,
+                duration_s=None,
+            ),
+            log_messages=[
+                get_langgraph_node_log_string(
+                    graph_component="main",
+                    node_name="generate refined answer",
+                    node_start_time=node_start_time,
+                    result=agent_error.error_result or "An LLM error occurred",
+                )
+            ],
+        )
 
     logger.debug(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"

@@ -266,49 +343,6 @@ def generate_refined_answer(
         revision_question_efficiency=revision_question_efficiency,
     )
 
-    logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
-    logger.debug("-" * 10)
-    logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
-
-    logger.debug("-" * 100)
-
-    if state.initial_agent_stats:
-        initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
-            "utilized_chunk_ratio", "--"
-        )
-        initial_support_boost_factor = (
-            state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
-        )
-        num_initial_verified_docs = state.initial_agent_stats.original_question.get(
-            "num_verified_documents", "--"
-        )
-        initial_verified_docs_avg_score = (
-            state.initial_agent_stats.original_question.get("verified_avg_score", "--")
-        )
-        initial_sub_questions_verified_docs = (
-            state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
-        )
-
-        logger.debug("INITIAL AGENT STATS")
-        logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
-        logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
-        logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
-        logger.debug(
-            f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
-        )
-        logger.debug(
-            f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
-        )
-    if refined_agent_stats:
-        logger.debug("-" * 10)
-        logger.debug("REFINED AGENT STATS")
-        logger.debug(
-            f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
-        )
-        logger.debug(
-            f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
-        )
-
     agent_refined_end_time = datetime.now()
     if state.agent_refined_start_time:
         agent_refined_duration = (

@@ -17,6 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
 from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
 from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
 from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )

@@ -76,6 +77,7 @@ class InitialAnswerUpdate(LoggerUpdate):
     """
 
     initial_answer: str | None = None
+    error: AgentErrorLoggingFormat | None = None
     initial_agent_stats: InitialAgentResultStats | None = None
     generated_sub_questions: list[str] = []
     agent_base_end_time: datetime | None = None

@@ -88,6 +90,7 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
     """
 
     refined_answer: str | None = None
+    error: AgentErrorLoggingFormat | None = None
     refined_agent_stats: RefinedAgentStats | None = None
     refined_answer_quality: bool = False

@@ -16,14 +16,40 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
     QueryExpansionUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
+)
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     QUERY_REWRITING_PROMPT,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
+    rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
+    general_error="Query rewriting failed due to LLM error - the original question will be used.",
+)
 
 
 def expand_queries(

@@ -54,13 +80,43 @@ def expand_queries(
         )
     ]
 
-    llm_response_list = dispatch_separated(
-        llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
-    )
-
-    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
-
-    rewritten_queries = llm_response.split("\n")
+    agent_error: AgentErrorLoggingFormat | None = None
+    llm_response_list: list[BaseMessage_Content] = []
+
+    try:
+        llm_response_list = dispatch_separated(
+            llm.stream(
+                prompt=msg,
+                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
+            ),
+            dispatch_subquery(level, question_num, writer),
+        )
+    except LLMTimeoutError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.TIMEOUT,
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result=_llm_node_error_strings.timeout,
+        )
+        logger.error("LLM Timeout Error - expand queries")
+    except LLMRateLimitError:
+        agent_error = AgentErrorLoggingFormat(
+            error_type=AgentLLMErrorType.RATE_LIMIT,
+            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
+            error_result=_llm_node_error_strings.rate_limit,
+        )
+        logger.error("LLM Rate Limit Error - expand queries")
+
+    # use subquestion as query if query generation fails
+    if agent_error:
+        llm_response = ""
+        rewritten_queries = [question]
+        log_result = agent_error.error_result
+    else:
+        llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
+            0
+        ].content
+        rewritten_queries = llm_response.split("\n")
+        log_result = f"Number of expanded queries: {len(rewritten_queries)}"
 
     return QueryExpansionUpdate(
         expanded_queries=rewritten_queries,

@@ -69,7 +125,7 @@ def expand_queries(
                 graph_component="shared - expanded retrieval",
                 node_name="expand queries",
                 node_start_time=node_start_time,
-                result=f"Number of expanded queries: {len(rewritten_queries)}",
+                result=log_result,
             )
         ],
     )

@@ -1,5 +1,6 @@
 from typing import cast
 
+from langchain_core.messages import BaseMessage
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables.config import RunnableConfig

@@ -10,12 +11,41 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
     DocVerificationUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
+    binary_string_test,
+)
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_RATELIMIT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_LLM_TIMEOUT_MESSAGE,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AGENT_POSITIVE_VALUE_STR,
+)
+from onyx.agents.agent_search.shared_graph_utils.constants import (
+    AgentLLMErrorType,
+)
+from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
+from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
+from onyx.llm.chat_llm import LLMRateLimitError
+from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     DOCUMENT_VERIFICATION_PROMPT,
 )
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+_llm_node_error_strings = LLMNodeErrorStrings(
+    timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
+    rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
+    general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
+)
 
 
 def verify_documents(

@@ -26,7 +56,7 @@ def verify_documents(
 
     Args:
         state (DocVerificationInput): The current state
-        config (RunnableConfig): Configuration containing ProSearchConfig
+        config (RunnableConfig): Configuration containing AgentSearchConfig
 
     Updates:
         verified_documents: list[InferenceSection]

@@ -51,11 +81,42 @@ def verify_documents(
|
||||
)
|
||||
]
|
||||
|
||||
response = fast_llm.invoke(msg)
|
||||
agent_error: AgentErrorLoggingFormat | None = None
|
||||
response: BaseMessage | None = None
|
||||
|
||||
verified_documents = []
|
||||
if isinstance(response.content, str) and "yes" in response.content.lower():
|
||||
verified_documents.append(retrieved_document_to_verify)
|
||||
try:
|
||||
response = fast_llm.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
except LLMTimeoutError:
|
||||
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as there is
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - verify documents")

if agent_error or response is None:
verified_documents = [retrieved_document_to_verify]

else:
verified_documents = []
if isinstance(response.content, str) and binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents.append(retrieved_document_to_verify)

return DocVerificationUpdate(
verified_documents=verified_documents,

@@ -21,9 +21,13 @@ from onyx.context.search.models import InferenceSection


class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
# exception from 'no default value' for LangGraph input states
# Here, sub_question_id default None implies usage for the
# original question. This is sometimes needed for nested sub-graphs

sub_question_id: str | None = None
question: str
base_search: bool


## Update/Return States
@@ -88,4 +92,4 @@ class DocVerificationInput(ExpandedRetrievalInput):


class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""
query_to_retrieve: str

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
MainInput as MainInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,6 +21,7 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -33,6 +34,7 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger


logger = setup_logger()

_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -72,13 +74,15 @@ def _parse_agent_event(
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
return None


def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput_a,
graph_input: BasicInput | MainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -92,7 +96,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput_a,
input: BasicInput | MainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -123,9 +127,7 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph()

input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])

# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@@ -172,9 +174,7 @@ if __name__ == "__main__":
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])

tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

@@ -150,3 +150,17 @@ def get_prompt_enrichment_components(
history=history,
date_str=date_str,
)


def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).

Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")

Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()
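# Illustrative usage of the helper above (example strings are ours, not from the
# original change): binary_string_test("Yes, the document is relevant.") -> True,
# while binary_string_test("No.", positive_value="yes") -> False.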

@@ -0,0 +1,17 @@
from enum import Enum

AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."

AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"


class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"
@@ -1,3 +1,5 @@
from typing import Any

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search.main.models import (
@@ -56,6 +58,12 @@ class InitialAgentResultStats(BaseModel):
agent_effectiveness: dict[str, float | int | None]


class AgentErrorLoggingFormat(BaseModel):
error_message: str
error_type: str
error_result: str | None = None
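# Illustrative note (mirrors the verify_documents call site earlier in this diff,
# not new behavior): a timeout is recorded as
# AgentErrorLoggingFormat(error_type=AgentLLMErrorType.TIMEOUT,
#     error_message=AGENT_LLM_TIMEOUT_MESSAGE,
#     error_result=_llm_node_error_strings.timeout)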


class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
@@ -126,3 +134,12 @@ class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str


class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"


BaseMessage_Content = str | list[str | dict[str, Any]]

@@ -20,6 +20,7 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -34,6 +35,9 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -46,6 +50,8 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -65,8 +71,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

BaseMessage_Content = str | list[str | dict[str, Any]]
logger = setup_logger()


# Post-processing
@@ -372,8 +379,24 @@ def summarize_history(
)
)

history_response = llm.invoke(history_context_prompt)
try:
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history  # this is what is done at this point anyway, so we default to this
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - summarize history")
return (
history  # this is what is done at this point anyway, so we default to this
)

assert isinstance(history_response.content, str)

return history_response.content



@@ -84,10 +84,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")

EXTRA_CONCURRENCY = 4  # small extra fudge factor for connection limits

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY)  # type: ignore
SqlEngine.init_engine(pool_size=8, max_overflow=0)

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

@@ -18,153 +18,238 @@ BEAT_EXPIRES_DEFAULT = 15 * 60  # 15 minutes (in seconds)

# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8

# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []

beat_task_templates.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)

# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
beat_task_templates.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1),  # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)


def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task: dict[str, Any] = {}

# constant options for cloud beat task generators
task_schedule: timedelta = task["schedule"]
cloud_task["schedule"] = task_schedule * CLOUD_BEAT_SCHEDULE_MULTIPLIER
cloud_task["options"] = {}
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT

# settings dependent on the original task
cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}"
cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]

optional_fields = ["queue", "priority", "expires"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]

return cloud_task
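# Illustrative example (values taken from the templates above): the
# "check-for-indexing" template with schedule=timedelta(seconds=15) becomes a
# cloud entry named "cloud_check-for-indexing" that runs
# OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR every
# 15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER seconds at OnyxCeleryPriority.HIGHEST,
# with kwargs {"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
# "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT}.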

CLOUD_BEAT_SCHEDULE_MULTIPLIER = 4

# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule: list[dict] = [
cloud_tasks_to_schedule = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1),
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]

# generate our cloud and self-hosted beat tasks from the templates
for beat_task_template in beat_task_templates:
cloud_task = make_cloud_generator_task(beat_task_template)
cloud_tasks_to_schedule.append(cloud_task)
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
),  # Check every hour
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
},
}
)

# tasks that run in either self-hosted or cloud
tasks_to_schedule: list[dict] = []

if not MULTI_TENANT:
tasks_to_schedule = beat_task_templates
tasks_to_schedule.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)

# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1),  # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)


def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:

@@ -186,7 +186,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
sync_type=SyncType.CONNECTOR_DELETION,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
pass

except TaskDependencyError:
redis_connector.delete.set_fence(None)

@@ -228,15 +228,12 @@ def try_creating_permissions_sync_task(

# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)

# set a basic fence to start
redis_connector.permissions.set_active()
@@ -260,10 +257,11 @@ def try_creating_permissions_sync_task(
)

# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)

payload_id = payload.id
payload_id = payload.celery_task_id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -292,8 +290,6 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""

payload_id: str | None = None

LoggerContextVars.reset()

doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
@@ -336,12 +332,9 @@ def connector_permission_sync_generator_task(
sleep(1)
continue

payload_id = payload.id

logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key} "
f"payload_id={payload.id}"
f"fence={redis_connector.permissions.fence_key}"
)
break

@@ -420,9 +413,7 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated

except Exception as e:
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")

redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
@@ -432,10 +423,6 @@ def connector_permission_sync_generator_task(
if lock.owned():
lock.release()

task_logger.info(
f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)


@shared_task(
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
@@ -672,7 +659,7 @@ def validate_permission_sync_fence(
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)

# we're active if there are still tasks to run and those tasks all exist in celery
# we're only active if tasks_scanned > 0 and tasks_not_in_celery == 0
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
@@ -693,8 +680,7 @@ def validate_permission_sync_fence(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
f"fence={fence_key}"
)

redis_connector.permissions.reset()

@@ -2,17 +2,15 @@ import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -34,9 +32,7 @@ from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
@@ -53,8 +49,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -112,11 +107,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)

# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
# r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -154,32 +149,30 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool

lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
payload_id = try_creating_external_group_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not payload_id:
if not tasks_created:
continue

task_logger.info(
f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")

# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# clear fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating external group sync fences"
)
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )

r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -198,11 +191,9 @@ def try_creating_external_group_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> str | None:
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
payload_id: str | None = None

redis_connector = RedisConnector(tenant_id, cc_pair_id)

LOCK_TIMEOUT = 30
@@ -224,28 +215,11 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()

# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")

# Signal active before creating fence
redis_connector.external_group_sync.set_active()

payload = RedisConnectorExternalGroupSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.external_group_sync.set_fence(payload)

custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

@@ -260,10 +234,17 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)

# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)

payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)

payload_id = payload.id
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -273,7 +254,7 @@ def try_creating_external_group_sync_task(
if lock.owned():
lock.release()

return payload_id
return 1


@shared_task(
@@ -331,8 +312,7 @@ def connector_external_group_sync_generator_task(

logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key} "
f"payload_id={payload.id}"
f"fence={redis_connector.external_group_sync.fence_key}"
)
break

@@ -401,7 +381,7 @@ def connector_external_group_sync_generator_task(
)
except Exception as e:
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)

with get_session_with_tenant(tenant_id) as db_session:
@@ -421,41 +401,32 @@ def connector_external_group_sync_generator_task(
if lock.owned():
lock.release()

task_logger.info(
f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)


def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_tasks = celery_get_unacked_task_ids(
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)

# validate all existing external group sync tasks
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
continue

validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_tasks,
r_celery,
)

# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()

with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return


@@ -464,6 +435,7 @@ def validate_external_group_sync_fence(
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
@@ -506,26 +478,26 @@ def validate_external_group_sync_fence(
if not redis_connector.external_group_sync.fenced:
return

try:
payload = redis_connector.external_group_sync.payload
except ValidationError:
task_logger.exception(
"validate_external_group_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)

redis_connector.external_group_sync.reset()
return

payload = redis_connector.external_group_sync.payload
if not payload:
return

if not payload.celery_task_id:
# OK, there's actually something for us to validate

if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
#     return

# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return

# OK, there's actually something for us to validate
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
@@ -555,8 +527,7 @@ def validate_external_group_sync_fence(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
f"fence={fence_key}"
)

redis_connector.external_group_sync.reset()

@@ -423,8 +423,8 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector,
redis_connector_index,
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
r,
)

@@ -99,16 +99,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
@@ -120,7 +120,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()

def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
if self.redis_client.exists(self.stop_key):
return True

return False
@@ -143,8 +143,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
# self.last_parent_check = now

try:
self.redis_connector.prune.set_active()

current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -167,9 +165,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise

self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
self.redis_client.incrby(self.generator_progress_key, amount)


def validate_indexing_fence(

@@ -1,39 +1,28 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
@@ -46,15 +35,10 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -109,8 +93,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -141,28 +123,13 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
if not _is_pruning_due(cc_pair):
continue

payload_id = try_creating_prune_generator_task(
tasks_created = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
if not tasks_created:
continue

task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)

# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")

r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -182,7 +149,7 @@ def try_creating_prune_generator_task(
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> str | None:
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.

@@ -201,7 +168,7 @@ def try_creating_prune_generator_task(

# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
@@ -233,30 +200,7 @@ def try_creating_prune_generator_task(

custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"

# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")

# signal active before the fence is set
redis_connector.prune.set_active()

# set a basic fence to start
payload = RedisConnectorPrunePayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.prune.set_fence(payload)

result = celery_app.send_task(
celery_app.send_task(
OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
@@ -269,11 +213,16 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)

# fill in the celery task id
payload.celery_task_id = result.id
redis_connector.prune.set_fence(payload)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)

payload_id = payload.id
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -281,7 +230,7 @@ def try_creating_prune_generator_task(
if lock.owned():
lock.release()

return payload_id
return 1


@shared_task(
@@ -303,8 +252,6 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""

payload_id: str | None = None

LoggerContextVars.reset()

pruning_ctx_dict = pruning_ctx.get()
@@ -318,46 +265,6 @@ def connector_pruning_generator_task(

r = get_redis_client(tenant_id=tenant_id)

# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_prune_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.prune.fence_key}"
)

if not redis_connector.prune.fenced:  # The fence must exist
raise ValueError(
f"connector_prune_generator_task - fence not found: "
f"fence={redis_connector.prune.fence_key}"
)

payload = redis_connector.prune.payload  # The payload must exist
if not payload:
raise ValueError(
"connector_prune_generator_task: payload invalid or not found"
)

if payload.celery_task_id is None:
logger.info(
f"connector_prune_generator_task - Waiting for fence: "
f"fence={redis_connector.prune.fence_key}"
)
time.sleep(1)
continue

payload_id = payload.id

logger.info(
f"connector_prune_generator_task - Fence found, continuing...: "
f"fence={redis_connector.prune.fence_key} "
f"payload_id={payload.id}"
)
break

# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
@@ -387,18 +294,6 @@ def connector_pruning_generator_task(
)
return

payload = redis_connector.prune.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

new_payload = RedisConnectorPrunePayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.prune.set_fence(new_payload)

task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
@@ -412,13 +307,10 @@ def connector_pruning_generator_task(
cc_pair.credential,
)

search_settings = get_current_search_settings(db_session)
redis_connector_index = redis_connector.new_index(search_settings.id)

callback = IndexingCallback(
0,
redis_connector,
redis_connector_index,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
@@ -465,9 +357,7 @@ def connector_pruning_generator_task(
redis_connector.prune.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(
f"Pruning exceptioned: cc_pair={cc_pair_id} "
f"connector={connector_id} "
f"payload_id={payload_id}"
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)

redis_connector.prune.reset()
@@ -476,9 +366,7 @@ def connector_pruning_generator_task(
if lock.owned():
lock.release()

task_logger.info(
f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")


"""Monitoring pruning utils, called in monitor_vespa_sync"""
@@ -527,184 +415,4 @@ def monitor_ccpair_pruning_taskset(

redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(None)


def validate_pruning_fences(
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024

queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return

# the queue for a single pruning generator task
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)

# the queue for a reasonably large set of lightweight deletion tasks
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)

# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
continue

validate_pruning_fence(
tenant_id,
key_bytes,
reserved_generator_tasks,
queued_upsert_tasks,
r,
r_celery,
)

lock_beat.reacquire()

return


def validate_pruning_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""See validate_indexing_fence for an overall idea of validation flows.

queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_pruning_fence - could not parse id from {fence_key}"
)
return

cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

# check to see if the fence/payload exists
if not redis_connector.prune.fenced:
return

# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.prune.payload
except ValidationError:
task_logger.exception(
"validate_pruning_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)

redis_connector.prune.reset()
return

if not payload:
return

if not payload.celery_task_id:
return

# OK, there's actually something for us to validate

# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_PRUNING,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.prune.set_active()
return

if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.prune.set_active()
return

# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own pruning taskset
# should be roughly a subset of the tasks in celery

# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.

# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.prune.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
if member_str in reserved_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_pruning_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.prune.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_pruning_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
return
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
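The active() check above reads a Redis key that set_active() refreshes with a TTL; while work is observed in flight the signal is re-armed, and after a crash it simply expires. A reduced sketch of that pattern, with the key name and TTL as illustrative assumptions:

import redis

r = redis.Redis()
ACTIVE_KEY = "connectorpruning_active_42"  # assumed key shape
ACTIVE_TTL = 3600  # seconds

def set_active() -> None:
    # re-arm the signal; it expires on its own if nothing refreshes it
    r.set(ACTIVE_KEY, 0, ex=ACTIVE_TTL)

def active() -> bool:
    return bool(r.exists(ACTIVE_KEY))
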
@@ -339,15 +339,11 @@ def try_generate_document_set_sync_tasks(

    # create before setting fence to avoid race condition where the monitoring
    # task updates the sync record before it is created
    try:
        insert_sync_record(
            db_session=db_session,
            entity_id=document_set_id,
            sync_type=SyncType.DOCUMENT_SET,
        )
    except Exception:
        task_logger.exception("insert_sync_record exceptioned.")

    insert_sync_record(
        db_session=db_session,
        entity_id=document_set_id,
        sync_type=SyncType.DOCUMENT_SET,
    )
    # set this only after all tasks have been added
    rds.set_fence(tasks_generated)
    return tasks_generated
@@ -415,15 +411,11 @@ def try_generate_user_group_sync_tasks(

    # create before setting fence to avoid race condition where the monitoring
    # task updates the sync record before it is created
    try:
        insert_sync_record(
            db_session=db_session,
            entity_id=usergroup_id,
            sync_type=SyncType.USER_GROUP,
        )
    except Exception:
        task_logger.exception("insert_sync_record exceptioned.")

    insert_sync_record(
        db_session=db_session,
        entity_id=usergroup_id,
        sync_type=SyncType.USER_GROUP,
    )
    # set this only after all tasks have been added
    rug.set_fence(tasks_generated)

@@ -912,7 +904,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:

    # use a lookup table to find active fences. We still have to verify the fence
    # exists since it is an optimization and not the source of truth.
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    keys = cast(set[Any], r.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)

@@ -140,7 +140,6 @@ def build_citations_user_message(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    all_doc_useful: bool,
    history_message: str = "",
    context_type: str = "context documents",
) -> HumanMessage:
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
@@ -157,7 +156,6 @@ def build_citations_user_message(
    optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

    user_prompt = CITATIONS_PROMPT.format(
        context_type=context_type,
        optional_ignore_statement=optional_ignore,
        context_docs_str=context_docs_str,
        task_prompt=task_prompt_with_reminder,
@@ -167,7 +165,6 @@ def build_citations_user_message(
    else:
        # if no context docs provided, assume we're in the tool calling flow
        user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
            context_type=context_type,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,

@@ -13,6 +13,21 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 3  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8  # in seconds

#####
# Agent Configs
#####
@@ -77,4 +92,76 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
    or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
)  # 2000


AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
)  # 25


AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)  # 3

AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
)  # 30


AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
)  # 8


AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
)  # 12


AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
)  # 25


AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
)  # 25


AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
)  # 8


AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
)  # 6


AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
)  # 4


AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
)  # 10


AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)  # 8


GRAPH_VERSION_NAME: str = "a"

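Each override above follows one idiom: os.environ.get(name) or default. Unlike passing a default to get(), the `or` also falls back when the variable is set but empty, which matters for blank values in env files. A small self-contained check:

import os

os.environ["AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK"] = ""  # blank, as from an empty .env entry
value = int(os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK") or 8)
assert value == 8  # the empty string is falsy, so the default wins
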
@@ -324,7 +324,6 @@ class OnyxRedisSignals:
    BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
        "signal:block_validate_permission_sync_fences"
    )
    BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
    BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"

@@ -27,7 +27,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -320,7 +319,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []

@@ -388,12 +386,4 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
            yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
            doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]

            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "retrieve_all_slim_documents: Stop signal detected"
                    )

                callback.progress("retrieve_all_slim_documents", 1)

        yield doc_metadata_list
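The stop/progress handshake removed here recurs in every connector's slim-document generator. A reduced, self-contained sketch of the pattern, with a stub class standing in for IndexingHeartbeatInterface:

from collections.abc import Iterator

class StubHeartbeat:
    """Stand-in for IndexingHeartbeatInterface; real implementations
    consult shared state to decide whether the run should stop."""

    def should_stop(self) -> bool:
        return False

    def progress(self, tag: str, amount: int) -> None:
        pass

def generate_batches(
    items: list[str],
    batch_size: int,
    callback: StubHeartbeat | None = None,
) -> Iterator[list[str]]:
    for start in range(0, len(items), batch_size):
        if callback:
            if callback.should_stop():
                raise RuntimeError("generate_batches: Stop signal detected")
            callback.progress("generate_batches", 1)
        yield items[start : start + batch_size]

assert list(generate_batches(["a", "b", "c"], 2, StubHeartbeat())) == [["a", "b"], ["c"]]
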
@@ -30,7 +30,6 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

@@ -322,7 +321,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        time_range_start: SecondsSinceUnixEpoch | None = None,
        time_range_end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        query = _build_time_range_query(time_range_start, time_range_end)
        doc_batch = []
@@ -345,15 +343,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                if len(doc_batch) > SLIM_BATCH_SIZE:
                    yield doc_batch
                    doc_batch = []

                    if callback:
                        if callback.should_stop():
                            raise RuntimeError(
                                "retrieve_all_slim_documents: Stop signal detected"
                            )

                        callback.progress("retrieve_all_slim_documents", 1)

        if doc_batch:
            yield doc_batch

@@ -379,10 +368,9 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        try:
            yield from self._fetch_slim_threads(start, end, callback=callback)
            yield from self._fetch_slim_threads(start, end)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

@@ -42,7 +42,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

@@ -565,7 +564,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_batch = []
        for file in self._fetch_drive_items(
@@ -578,26 +576,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
            if len(slim_batch) >= SLIM_BATCH_SIZE:
                yield slim_batch
                slim_batch = []
                if callback:
                    if callback.should_stop():
                        raise RuntimeError(
                            "_extract_slim_docs_from_google_drive: Stop signal detected"
                        )

                    callback.progress("_extract_slim_docs_from_google_drive", 1)

        yield slim_batch

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        try:
            yield from self._extract_slim_docs_from_google_drive(
                start, end, callback=callback
            )
            yield from self._extract_slim_docs_from_google_drive(start, end)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

@@ -7,7 +7,6 @@ from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface


SecondsSinceUnixEpoch = float
@@ -64,7 +63,6 @@ class SlimConnector(BaseConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        raise NotImplementedError

@@ -29,7 +29,6 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger


@@ -246,7 +245,6 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        jql = f"project = {self.quoted_jira_project}"

@@ -21,7 +21,6 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -177,7 +176,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []
        for parent_object_type in self.parent_object_list:

@@ -21,7 +21,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger


@@ -243,7 +242,6 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_doc_batch: list[SlimDocument] = []
        for post_id in get_all_post_ids(self.slab_bot_token):

@@ -27,7 +27,6 @@ from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger


@@ -99,7 +98,6 @@ def get_channel_messages(
    channel: dict[str, Any],
    oldest: str | None = None,
    latest: str | None = None,
    callback: IndexingHeartbeatInterface | None = None,
) -> Generator[list[MessageType], None, None]:
    """Get all messages in a channel"""
    # join so that the bot can access messages
@@ -117,11 +115,6 @@ def get_channel_messages(
        oldest=oldest,
        latest=latest,
    ):
        if callback:
            if callback.should_stop():
                raise RuntimeError("get_channel_messages: Stop signal detected")

            callback.progress("get_channel_messages", 0)
        yield cast(list[MessageType], result["messages"])


@@ -332,7 +325,6 @@ def _get_all_doc_ids(
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """
    Get all document ids in the workspace, channel by channel
@@ -350,7 +342,6 @@ def _get_all_doc_ids(
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
            callback=callback,
        )

        message_ts_set: set[str] = set()
@@ -399,7 +390,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")
@@ -408,7 +398,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            callback=callback,
        )

    def poll_source(

@@ -20,7 +20,6 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder


@@ -406,7 +405,6 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_doc_batch: list[SlimDocument] = []
        if self.content_type == "articles":

@@ -152,7 +152,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    # if not specified, all assistants are shown
    temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    chosen_assistants: Mapped[list[int] | None] = mapped_column(
        postgresql.JSONB(), nullable=True, default=None
    )

@@ -228,7 +228,6 @@ def create_update_persona(
        num_chunks=create_persona_request.num_chunks,
        llm_relevance_filter=create_persona_request.llm_relevance_filter,
        llm_filter_extraction=create_persona_request.llm_filter_extraction,
        is_default_persona=create_persona_request.is_default_persona,
    )

    versioned_make_persona_private = fetch_versioned_implementation(

@@ -27,7 +27,6 @@ from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
@@ -36,7 +35,6 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -52,6 +50,18 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"


class LLMTimeoutError(Exception):
    """
    Exception raised when an LLM call times out.
    """


class LLMRateLimitError(Exception):
    """
    Exception raised when an LLM call is rate limited.
    """


def _base_msg_to_role(msg: BaseMessage) -> str:
    if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
        return "user"
@@ -231,15 +241,15 @@ class DefaultMultiLLM(LLM):
    def __init__(
        self,
        api_key: str | None,
        timeout: int,
        model_provider: str,
        model_name: str,
        timeout: int | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        max_output_tokens: int | None = None,
        custom_llm_provider: str | None = None,
        temperature: float | None = None,
        temperature: float = GEN_AI_TEMPERATURE,
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
@@ -247,16 +257,9 @@ class DefaultMultiLLM(LLM):
        long_term_logger: LongTermLogger | None = None,
    ):
        self._timeout = timeout
        if timeout is None:
            if model_is_reasoning_model(model_name):
                self._timeout = QA_TIMEOUT * 10  # Reasoning models are slow
            else:
                self._timeout = QA_TIMEOUT

        self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature

        self._model_provider = model_provider
        self._model_version = model_name
        self._temperature = temperature
        self._api_key = api_key
        self._deployment_name = deployment_name
        self._api_base = api_base
@@ -389,6 +392,7 @@ class DefaultMultiLLM(LLM):
        tool_choice: ToolChoiceOptions | None,
        stream: bool,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
        # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
        # to a dict representation
@@ -414,7 +418,7 @@ class DefaultMultiLLM(LLM):
            stream=stream,
            # model params
            temperature=0,
            timeout=self._timeout,
            timeout=timeout_override or self._timeout,
            # For now, we don't support parallel tool calls
            # NOTE: we can't pass this in if tools are not specified
            # or else OpenAI throws an error
@@ -433,6 +437,12 @@ class DefaultMultiLLM(LLM):
        except Exception as e:
            self._record_error(processed_prompt, e)
            # for break pointing
            if isinstance(e, litellm.Timeout):
                raise LLMTimeoutError(e)

            elif isinstance(e, litellm.RateLimitError):
                raise LLMRateLimitError(e)

            raise e

    @property
@@ -453,6 +463,7 @@ class DefaultMultiLLM(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()
@@ -460,7 +471,12 @@ class DefaultMultiLLM(LLM):
        response = cast(
            litellm.ModelResponse,
            self._completion(
                prompt, tools, tool_choice, False, structured_response_format
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=False,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
            ),
        )
        choice = response.choices[0]
@@ -478,19 +494,31 @@ class DefaultMultiLLM(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
            yield self.invoke(
                prompt,
                tools,
                tool_choice,
                structured_response_format,
                timeout_override,
            )
            return

        output = None
        response = cast(
            litellm.CustomStreamWrapper,
            self._completion(
                prompt, tools, tool_choice, True, structured_response_format
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=True,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
            ),
        )
        try:

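With the narrowed exceptions in place, callers can branch on timeouts and rate limits without importing litellm. A hedged usage sketch; the constructor arguments and prompt shape are simplified placeholders, not a definitive call pattern:

llm = DefaultMultiLLM(
    api_key=None,  # placeholder; real deployments supply provider credentials
    model_provider="openai",
    model_name="gpt-4o",
)
try:
    # the per-call override takes precedence over the timeout chosen in __init__
    message = llm.invoke("hello", timeout_override=5)
except LLMTimeoutError:
    ...  # retry with a larger budget or surface the failure
except LLMRateLimitError:
    ...  # back off before retrying
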
@@ -81,6 +81,7 @@ class CustomModelServer(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        return self._execute(prompt)

@@ -90,5 +91,6 @@ class CustomModelServer(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._execute(prompt)

@@ -2,6 +2,7 @@ from typing import Any

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
@@ -87,8 +88,8 @@ def get_llms_for_persona(


def get_default_llms(
    timeout: int | None = None,
    temperature: float | None = None,
    timeout: int = QA_TIMEOUT,
    temperature: float = GEN_AI_TEMPERATURE,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
@@ -137,7 +138,7 @@ def get_llm(
    api_version: str | None = None,
    custom_config: dict[str, str] | None = None,
    temperature: float | None = None,
    timeout: int | None = None,
    timeout: int = QA_TIMEOUT,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:

@@ -90,12 +90,13 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format
            prompt, tools, tool_choice, structured_response_format, timeout_override
        )

    @abc.abstractmethod
@@ -105,6 +106,7 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

@@ -114,12 +116,13 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        messages = self._stream_implementation(
            prompt, tools, tool_choice, structured_response_format, timeout_override
        )

        tokens = []
@@ -138,5 +141,6 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError

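The interface uses a template-method split: the public invoke/stream entry points do shared bookkeeping, then forward every argument, including the new timeout_override, to the subclass hook. A runnable miniature of that dispatch; MiniLLM and EchoLLM are illustrative names, not part of the repo:

class MiniLLM:
    def invoke(self, prompt: str, timeout_override: int | None = None) -> str:
        self._precall(prompt)  # shared pre-flight checks
        return self._invoke_implementation(prompt, timeout_override)

    def _precall(self, prompt: str) -> None:
        if not prompt:
            raise ValueError("empty prompt")

    def _invoke_implementation(self, prompt: str, timeout_override: int | None) -> str:
        raise NotImplementedError

class EchoLLM(MiniLLM):
    def _invoke_implementation(self, prompt: str, timeout_override: int | None) -> str:
        return f"echo({prompt}, timeout={timeout_override})"

assert EchoLLM().invoke("hi", timeout_override=5) == "echo(hi, timeout=5)"
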
@@ -29,11 +29,11 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
    "o3-mini",
    "o1-mini",
    "o1",
    "o1-preview",
    "o1-2024-12-17",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "o1-preview",
    "gpt-4-turbo",
    "gpt-4-turbo-preview",
    "gpt-4-1106-preview",

@@ -543,14 +543,3 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
            f"Failed to get model object for {model_provider}/{model_name}"
        )
        return False


def model_is_reasoning_model(model_name: str) -> bool:
    _REASONING_MODEL_NAMES = [
        "o1",
        "o1-mini",
        "o3-mini",
        "deepseek-reasoner",
        "deepseek-r1",
    ]
    return model_name.lower() in _REASONING_MODEL_NAMES

@@ -5,8 +5,6 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
YES = "yes"
NO = "no"


# Framing/Support/Template Prompts
HISTORY_FRAMING_PROMPT = f"""
For more context, here is the history of the conversation so far that preceded this question:

@@ -91,7 +91,7 @@ SAMPLE RESPONSE:
# similar to the chat flow, but with the option of including a
# "conversation history" block
CITATIONS_PROMPT = f"""
Refer to the following {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT}
Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT}

CONTEXT:
{GENERAL_SEP_PAT}
@@ -108,7 +108,7 @@ CONTEXT:
# NOTE: need to add the extra line about "getting right to the point" since the
# tool calling models from OpenAI tend to be more verbose
CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT} \
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.

{{history_block}}{{task_prompt}}

@@ -80,8 +80,7 @@ class RedisConnectorPermissionSync:
    def get_active_task_count(self) -> int:
        """Count of active permission sync tasks"""
        count = 0
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
        for _ in self.redis.scan_iter(
            RedisConnectorPermissionSync.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):

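The count above can be taken either by scanning the keyspace for fence-shaped key names or by iterating a lookup set that indexes active fences, which is what this change switches between. Both idioms in miniature; the key pattern and set name are assumptions for the example:

import redis

r = redis.Redis()

def count_by_key_scan() -> int:
    # SCAN the keyspace for keys matching an assumed fence prefix
    return sum(1 for _ in r.scan_iter("connectorpermissions_fence_*", count=4096))

def count_by_lookup_set() -> int:
    # SSCAN a set that indexes active fences (assumed set name)
    return sum(1 for _ in r.sscan_iter("active_fences"))
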
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast

import redis
@@ -7,12 +8,10 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorExternalGroupSyncPayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None
@@ -38,12 +37,6 @@ class RedisConnectorExternalGroupSync:
    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorexternalgroupsync_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorexternalgroupsync+sub

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
@@ -57,7 +50,6 @@ class RedisConnectorExternalGroupSync:
        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)
@@ -74,8 +66,7 @@ class RedisConnectorExternalGroupSync:
    def get_active_task_count(self) -> int:
        """Count of active external group syncing tasks"""
        count = 0
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
        for _ in self.redis.scan_iter(
            RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):
@@ -92,11 +83,10 @@ class RedisConnectorExternalGroupSync:
    @property
    def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
        # read related data and evaluate/print task progress
        fence_raw = self.redis.get(self.fence_key)
        if fence_raw is None:
        fence_bytes = cast(Any, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_bytes = cast(bytes, fence_raw)
        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
            cast(str, fence_str)
@@ -109,26 +99,10 @@ class RedisConnectorExternalGroupSync:
        payload: RedisConnectorExternalGroupSyncPayload | None,
    ) -> None:
        if not payload:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

    def set_active(self) -> None:
        """This sets a signal to keep the permissioning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed to avoid race conditions where simply checking
        the celery queue and task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False

    @property
    def generator_complete(self) -> int | None:
@@ -164,8 +138,6 @@ class RedisConnectorExternalGroupSync:
        pass

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
@@ -180,9 +152,6 @@ class RedisConnectorExternalGroupSync:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorExternalGroupSync.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
            r.delete(key)

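In this class the fence value doubles as the payload: a JSON-serialized pydantic model written under the fence key and indexed in the ACTIVE_FENCES set. A reduced sketch of that round trip, with key and set names assumed for the example:

from datetime import datetime

import redis
from pydantic import BaseModel

class FencePayload(BaseModel):
    id: str
    submitted: datetime
    celery_task_id: str | None

r = redis.Redis()
FENCE_KEY = "connectorexternalgroupsync_fence_42"  # assumed key shape

def set_fence(payload: FencePayload | None) -> None:
    if payload is None:
        # clearing the fence also drops it from the lookup set
        r.srem("active_fences", FENCE_KEY)
        r.delete(FENCE_KEY)
        return
    r.set(FENCE_KEY, payload.model_dump_json())
    r.sadd("active_fences", FENCE_KEY)

def get_payload() -> FencePayload | None:
    raw = r.get(FENCE_KEY)
    if raw is None:
        return None
    return FencePayload.model_validate_json(raw.decode("utf-8"))
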
@@ -1,11 +1,9 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

@@ -18,13 +16,6 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorPrunePayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None


class RedisConnectorPrune:
    """Manages interactions with redis for pruning tasks. Should only be accessed
    through RedisConnector."""
@@ -45,12 +36,6 @@ class RedisConnectorPrune:
    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpruning_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpruning+sub

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
@@ -64,7 +49,6 @@ class RedisConnectorPrune:
        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)
@@ -81,10 +65,8 @@ class RedisConnectorPrune:
    def get_active_task_count(self) -> int:
        """Count of active pruning tasks"""
        count = 0
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
            RedisConnectorPrune.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        for key in self.redis.scan_iter(
            RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
        ):
            count += 1
        return count
@@ -96,44 +78,15 @@ class RedisConnectorPrune:

        return False

    @property
    def payload(self) -> RedisConnectorPrunePayload | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorPrunePayload.model_validate_json(cast(str, fence_str))

        return payload

    def set_fence(
        self,
        payload: RedisConnectorPrunePayload | None,
    ) -> None:
        if not payload:
    def set_fence(self, value: bool) -> None:
        if not value:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.set(self.fence_key, 0)
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

    def set_active(self) -> None:
        """This sets a signal to keep the permissioning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed to avoid race conditions where simply checking
        the celery queue and task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
@@ -209,7 +162,6 @@ class RedisConnectorPrune:

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
@@ -224,9 +176,6 @@ class RedisConnectorPrune:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
            r.delete(key)

@@ -368,17 +368,15 @@ def prune_cc_pair(
        f"credential={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_prune_generator_task(
    tasks_created = try_creating_prune_generator_task(
        primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
    )
    if not payload_id:
    if not tasks_created:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Pruning task creation failed.",
        )

    logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the pruning task.",
@@ -516,17 +514,15 @@ def sync_cc_pair_groups(
        f"credential_id={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_external_group_sync_task(
    tasks_created = try_creating_external_group_sync_task(
        primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
    )
    if not payload_id:
    if not tasks_created:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="External group sync task creation failed.",
        )

    logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the external group sync task.",

@@ -279,5 +279,4 @@ class InternetSearchTool(Tool):
            using_tool_calling_llm=using_tool_calling_llm,
            answer_style_config=self.answer_style_config,
            prompt_config=self.prompt_config,
            context_type="internet search results",
        )

@@ -25,7 +25,6 @@ def build_next_prompt_for_search_like_tool(
    using_tool_calling_llm: bool,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    context_type: str = "context documents",
) -> AnswerPromptBuilder:
    if not using_tool_calling_llm:
        final_context_docs_response = next(
@@ -59,7 +58,6 @@ def build_next_prompt_for_search_like_tool(
            else False
        ),
        history_message=prompt_builder.single_message_history or "",
        context_type=context_type,
        )
    )

@@ -4,7 +4,7 @@
  "rsc": true,
  "tsx": true,
  "tailwind": {
    "config": "tailwind-themes/tailwind.config.js",
    "config": "tailwind.config.js",
    "css": "src/app/globals.css",
    "baseColor": "neutral",
    "cssVariables": false,

web/package-lock.json (generated, 1021 lines changed): file diff suppressed because it is too large.
@@ -4,7 +4,7 @@
|
||||
"version-comment": "version field must be SemVer or chromatic will barf",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev --turbo",
|
||||
"dev": "next dev --turbopack",
|
||||
"build": "next build",
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
@@ -21,17 +21,17 @@
|
||||
"@radix-ui/react-accordion": "^1.2.2",
|
||||
"@radix-ui/react-checkbox": "^1.1.2",
|
||||
"@radix-ui/react-collapsible": "^1.1.2",
|
||||
"@radix-ui/react-dialog": "^1.1.6",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.6",
|
||||
"@radix-ui/react-dialog": "^1.1.2",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.4",
|
||||
"@radix-ui/react-label": "^2.1.1",
|
||||
"@radix-ui/react-popover": "^1.1.6",
|
||||
"@radix-ui/react-popover": "^1.1.2",
|
||||
"@radix-ui/react-radio-group": "^1.2.2",
|
||||
"@radix-ui/react-scroll-area": "^1.2.2",
|
||||
"@radix-ui/react-select": "^2.1.6",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-separator": "^1.1.0",
|
||||
"@radix-ui/react-slider": "^1.2.2",
|
||||
"@radix-ui/react-slot": "^1.1.2",
|
||||
"@radix-ui/react-switch": "^1.1.3",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-switch": "^1.1.1",
|
||||
"@radix-ui/react-tabs": "^1.1.1",
|
||||
"@radix-ui/react-tooltip": "^1.1.3",
|
||||
"@sentry/nextjs": "^8.50.0",
|
||||
@@ -56,7 +56,6 @@
|
||||
"lucide-react": "^0.454.0",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^15.0.2",
|
||||
"next-themes": "^0.4.4",
|
||||
"npm": "^10.8.0",
|
||||
"postcss": "^8.4.31",
|
||||
"posthog-js": "^1.176.0",
|
||||
|
||||
BIN web/public/LiteLLM.jpg (new binary file, not shown; 12 KiB)
BIN web/public/discord.png (new binary file, not shown; 9.9 KiB)
BIN (deleted binary file, not shown; was 4.0 KiB)
BIN (deleted binary file, not shown; was 89 KiB)
BIN (deleted binary file, not shown; was 4.8 KiB)
BIN (deleted binary file, not shown; was 14 KiB)
@@ -27,12 +27,8 @@ function SourceTile({
        w-40
        cursor-pointer
        shadow-md
        hover:bg-accent-background-hovered
        ${
          preSelect
            ? "bg-accent-background-hovered subtle-pulse"
            : "bg-accent-background"
        }
        hover:bg-hover
        ${preSelect ? "bg-hover subtle-pulse" : "bg-hover-light"}
      `}
      href={sourceMetadata.adminUrl}
    >

@@ -56,7 +56,7 @@ function NewApiKeyModal({
      <div className="flex mt-2">
        <b className="my-auto break-all">{apiKey}</b>
        <div
          className="ml-2 my-auto p-2 hover:bg-accent-background-hovered rounded cursor-pointer"
          className="ml-2 my-auto p-2 hover:bg-hover rounded cursor-pointer"
          onClick={() => {
            setCopyClicked(true);
            navigator.clipboard.writeText(apiKey);
@@ -112,10 +112,7 @@ function Main() {
  }

  const newApiKeyButton = (
    <CreateButton
      onClick={() => setShowCreateUpdateForm(true)}
      text="Create API Key"
    />
    <CreateButton href="/admin/api-key/new" text="Create API Key" />
  );

  if (apiKeys.length === 0) {
@@ -182,7 +179,7 @@ function Main() {
            flex
            mb-1
            w-fit
            hover:bg-accent-background-hovered cursor-pointer
            hover:bg-hover cursor-pointer
            p-2
            rounded-lg
            border-border
@@ -206,7 +203,7 @@ function Main() {
            flex
            mb-1
            w-fit
            hover:bg-accent-background-hovered cursor-pointer
            hover:bg-hover cursor-pointer
            p-2
            rounded-lg
            border-border

@@ -825,7 +825,10 @@ export function AssistantEditor({
                  </TooltipProvider>
                </div>
              </div>
              <p className="text-sm text-neutral-700 dark:text-neutral-400">
              <p
                className="text-sm text-subtle"
                style={{ color: "rgb(113, 114, 121)" }}
              >
                Attach additional unique knowledge to this assistant
              </p>
            </div>
@@ -1214,7 +1217,7 @@ export function AssistantEditor({
                    setFieldValue("label_ids", newLabelIds);
                  }}
                  itemComponent={({ option }) => (
                    <div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-accent-background-hovered cursor-pointer border-b border-border last:border-b-0">
                    <div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-hover cursor-pointer border-b border-border last:border-b-0">
                      <div
                        className="flex-grow"
                        onClick={() => {
@@ -1353,7 +1356,7 @@ export function AssistantEditor({
                </>
              )}

              <div className="mt-12 gap-x-2 w-full justify-end flex">
              <div className="mt-12 gap-x-2 w-full justify-end flex">
                <Button
                  type="submit"
                  disabled={isSubmitting || isRequestSuccessful}

@@ -31,7 +31,7 @@ export function HidableSection({
  return (
    <div>
      <div
        className="flex hover:bg-accent-background rounded cursor-pointer p-2"
        className="flex hover:bg-hover-light rounded cursor-pointer p-2"
        onClick={() => setIsHidden(!isHidden)}
      >
        <SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>

@@ -187,9 +187,7 @@ export function PersonasTable() {
                }
              }}
              className={`px-1 py-0.5 rounded flex ${
                isEditable
                  ? "hover:bg-accent-background-hovered cursor-pointer"
                  : ""
                isEditable ? "hover:bg-hover cursor-pointer" : ""
              } select-none w-fit`}
            >
              <div className="my-auto w-12">
@@ -207,7 +205,7 @@ export function PersonasTable() {
          <div className="mr-auto my-auto">
            {!persona.builtin_persona && isEditable ? (
              <div
                className="hover:bg-accent-background-hovered rounded p-1 cursor-pointer"
                className="hover:bg-hover rounded p-1 cursor-pointer"
                onClick={() => openDeleteModal(persona)}
              >
                <TrashIcon />

@@ -65,7 +65,7 @@ export default function StarterMessagesList({
                onClick={() => {
                  arrayHelpers.remove(index);
                }}
                className={`text-text-400 hover:text-red-500 ${
                className={`text-gray-400 hover:text-red-500 ${
                  index === values.length - 1 && !starterMessage.message
                    ? "opacity-50 cursor-not-allowed"
                    : ""
@@ -105,7 +105,7 @@ export default function StarterMessagesList({
                4 ||
                isRefreshing ||
                !autoStarterMessageEnabled
                  ? "bg-background-800 text-text-300 cursor-not-allowed"
                  ? "bg-neutral-800 text-neutral-300 cursor-not-allowed"
                  : ""
              }
            `}

@@ -1,3 +1,4 @@
"use client";
import { PersonasTable } from "./PersonaTable";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";

@@ -92,9 +92,9 @@ export const ExistingSlackBotForm = ({

      <div className="flex flex-col" ref={dropdownRef}>
        <div className="flex items-center gap-4">
          <div className="border rounded-lg border-background-200">
          <div className="border rounded-lg border-gray-200">
            <div
              className="flex items-center gap-2 cursor-pointer hover:bg-background-100 p-2"
              className="flex items-center gap-2 cursor-pointer hover:bg-gray-100 p-2"
              onClick={() => setIsExpanded(!isExpanded)}
            >
              {isExpanded ? (
@@ -117,7 +117,7 @@ export const ExistingSlackBotForm = ({
        </div>

        {isExpanded && (
          <div className="bg-white border rounded-lg border-background-200 shadow-lg absolute mt-12 right-0 z-10 w-full md:w-3/4 lg:w-1/2">
          <div className="bg-white border rounded-lg border-gray-200 shadow-lg absolute mt-12 right-0 z-10 w-full md:w-3/4 lg:w-1/2">
            <div className="p-4">
              <SlackTokensForm
                isUpdate={true}
@@ -134,7 +134,7 @@ export const ExistingSlackBotForm = ({
        </div>
      </div>
      <div className="mt-2">
        <div className="inline-block border rounded-lg border-background-200 p-2">
        <div className="inline-block border rounded-lg border-gray-200 p-2">
          <Checkbox
            label="Enabled"
            checked={formValues.enabled}

@@ -613,7 +613,7 @@ export function SlackChannelConfigFormFields({
              <Link
                key={ccpairinfo.id}
                href={`/admin/connector/${ccpairinfo.id}`}
                className="flex items-center p-2 rounded-md hover:bg-background-100 transition-colors"
                className="flex items-center p-2 rounded-md hover:bg-gray-100 transition-colors"
              >
                <div className="mr-2">
                  <SourceIcon

@@ -84,7 +84,7 @@ function Main() {
          {isApiKeySet ? (
            <div className="w-full p-3 border rounded-md bg-background text-text flex items-center">
              <span className="flex-grow">••••••••••••••••</span>
              <Lock className="h-5 w-5 text-text-400" />
              <Lock className="h-5 w-5 text-gray-400" />
            </div>
          ) : (
            <input

@@ -25,7 +25,8 @@ function LLMProviderUpdateModal({
}) {
  const providerName = existingLlmProvider?.name
    ? `"${existingLlmProvider.name}"`
    : llmProviderDescriptor?.display_name ||
    : null ||
      llmProviderDescriptor?.display_name ||
      llmProviderDescriptor?.name ||
      "Custom LLM Provider";
  return (
@@ -74,7 +75,7 @@ function LLMProviderDisplay({
  return (
    <div>
      {popup}
      <div className="border border-border p-3 dark:bg-neutral-800 dark:border-neutral-700 rounded w-96 flex shadow-md">
      <div className="border border-border p-3 rounded w-96 flex shadow-md">
        <div className="my-auto">
          <div className="font-bold">{providerName} </div>
          <div className="text-xs italic">({existingLlmProvider.provider})</div>
@@ -112,7 +113,7 @@ function LLMProviderDisplay({
        {existingLlmProvider && (
          <div className="my-auto ml-3">
            {existingLlmProvider.is_default_provider ? (
              <Badge variant="agent">Default</Badge>
              <Badge variant="orange">Default</Badge>
            ) : (
              <Badge variant="success">Enabled</Badge>
            )}

@@ -348,7 +348,7 @@ export function CustomLLMProviderUpdateForm({
|
||||
</div>
|
||||
<div className="my-auto">
|
||||
<FiX
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-accent-background-hovered rounded p-2"
|
||||
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
|
||||
onClick={() => arrayHelpers.remove(index)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -73,7 +73,7 @@ function DefaultLLMProviderDisplay({
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
<div className="border border-border p-3 dark:bg-neutral-800 dark:border-neutral-700 rounded w-96 flex shadow-md">
|
||||
<div className="border border-border p-3 rounded w-96 flex shadow-md">
|
||||
<div className="my-auto">
|
||||
<div className="font-bold">{providerName} </div>
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,6 @@ import {
|
||||
OpenSourceIcon,
|
||||
AnthropicSVG,
|
||||
IconProps,
|
||||
OpenAIISVG,
|
||||
} from "@/components/icons/icons";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
|
||||
@@ -105,7 +104,7 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
|
||||
switch (providerName) {
|
||||
case "openai":
|
||||
// Special cases for openai based on modelName
|
||||
return modelNameToIcon(modelName || "", OpenAIISVG);
|
||||
return modelNameToIcon(modelName || "", OpenAIIcon);
|
||||
case "anthropic":
|
||||
return AnthropicSVG;
|
||||
case "bedrock":
|
||||
|
||||
@@ -111,14 +111,14 @@ function Main() {
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Text className="font-semibold">Reranking Model</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text className="text-gray-700">
|
||||
{searchSettings.rerank_model_name || "Not set"}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Text className="font-semibold">Results to Rerank</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text className="text-gray-700">
|
||||
{searchSettings.num_rerank}
|
||||
</Text>
|
||||
</div>
|
||||
@@ -127,7 +127,7 @@ function Main() {
|
||||
<Text className="font-semibold">
|
||||
Multilingual Expansion
|
||||
</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text className="text-gray-700">
|
||||
{searchSettings.multilingual_expansion.length > 0
|
||||
? searchSettings.multilingual_expansion.join(", ")
|
||||
: "None"}
|
||||
@@ -136,7 +136,7 @@ function Main() {
|
||||
|
||||
<div>
|
||||
<Text className="font-semibold">Multipass Indexing</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text className="text-gray-700">
|
||||
{searchSettings.multipass_indexing
|
||||
? "Enabled"
|
||||
: "Disabled"}
|
||||
@@ -147,7 +147,7 @@ function Main() {
|
||||
<Text className="font-semibold">
|
||||
Disable Reranking for Streaming
|
||||
</Text>
|
||||
<Text className="text-text-700">
|
||||
<Text className="text-gray-700">
|
||||
{searchSettings.disable_rerank_for_streaming
|
||||
? "Yes"
|
||||
: "No"}
|
||||
|
||||
@@ -149,7 +149,7 @@ export function AdvancedConfigDisplay({
|
||||
<>
|
||||
<Title className="mt-8 mb-2">Advanced Configuration</Title>
|
||||
<CardSection>
|
||||
<ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
|
||||
<ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
|
||||
{pruneFreq && (
|
||||
<li
|
||||
key={0}
|
||||
|
||||
@@ -11,7 +11,7 @@ export default function DeletionErrorStatus({
|
||||
<h3 className="text-base font-medium">Deletion Error</h3>
|
||||
<div className="ml-2 relative group">
|
||||
<FiInfo className="h-4 w-4 text-error-600 cursor-help" />
|
||||
<div className="absolute z-10 w-64 p-2 mt-2 text-sm bg-white rounded-md shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-300 border border-background-200">
|
||||
<div className="absolute z-10 w-64 p-2 mt-2 text-sm bg-white rounded-md shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-300 border border-gray-200">
|
||||
This error occurred while attempting to delete the connector. You
|
||||
may re-attempt a deletion by clicking the "Delete" button.
|
||||
</div>
|
||||
|
||||
@@ -293,7 +293,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
<b className="text-emphasis">{ccPair.num_docs_indexed}</b>
|
||||
</div>
|
||||
{!ccPair.is_editable_for_current_user && (
|
||||
<div className="text-sm mt-2 text-text-500 italic">
|
||||
<div className="text-sm mt-2 text-neutral-500 italic">
|
||||
{ccPair.access_type === "public"
|
||||
? "Public connectors are not editable by curators."
|
||||
: ccPair.access_type === "sync"
|
||||
|
||||
@@ -62,7 +62,6 @@ import {
|
||||
} from "@/lib/connectors/oauth";
|
||||
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { Button } from "@/components/ui/button";
|
||||
export interface AdvancedConfig {
|
||||
refreshFreq: number;
|
||||
pruneFreq: number;
|
||||
@@ -465,9 +464,8 @@ export default function AddConnector({
|
||||
{!createCredentialFormToggle && (
|
||||
<div className="mt-6 flex space-x-4">
|
||||
{/* Button to pop up a form to manually enter credentials */}
|
||||
<Button
|
||||
variant="secondary"
|
||||
className="mt-6 text-sm mr-4"
|
||||
<button
|
||||
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
|
||||
onClick={async () => {
|
||||
if (oauthDetails && oauthDetails.oauth_enabled) {
|
||||
if (oauthDetails.additional_kwargs.length > 0) {
|
||||
@@ -497,7 +495,7 @@ export default function AddConnector({
|
||||
}}
|
||||
>
|
||||
Create New
|
||||
</Button>
|
||||
</button>
|
||||
{/* Button to sign in via OAuth */}
|
||||
{oauthSupportedSources.includes(connector) &&
|
||||
(NEXT_PUBLIC_CLOUD_ENABLED ||
|
||||
|
||||
@@ -50,10 +50,11 @@ const NavigationRow = ({
|
||||
</SquareNavigationButton>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-center">
|
||||
{(formStep > 0 || noCredentials) && (
|
||||
<SquareNavigationButton
|
||||
className="bg-agent text-white py-2.5 px-3.5 disabled:opacity-50"
|
||||
className="bg-accent text-white py-2.5 px-3.5 disabled:opacity-50"
|
||||
disabled={!isValid}
|
||||
onClick={onSubmit}
|
||||
>
|
||||
@@ -62,6 +63,7 @@ const NavigationRow = ({
|
||||
</SquareNavigationButton>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end">
|
||||
{formStep === 0 && (
|
||||
<SquareNavigationButton
|
||||
|
||||
@@ -96,7 +96,7 @@ export default function Sidebar() {
|
||||
<div className="mx-auto w-full max-w-2xl px-4 py-8">
|
||||
<div className="relative">
|
||||
{connector != "file" && (
|
||||
<div className="absolute h-[85%] left-[6px] top-[8px] bottom-0 w-0.5 bg-background-300"></div>
|
||||
<div className="absolute h-[85%] left-[6px] top-[8px] bottom-0 w-0.5 bg-gray-300"></div>
|
||||
)}
|
||||
{settingSteps.map((step, index) => {
|
||||
const allowed =
|
||||
@@ -119,7 +119,7 @@ export default function Sidebar() {
|
||||
<div className="flex-shrink-0 mr-4 z-10">
|
||||
<div
|
||||
className={`rounded-full h-3.5 w-3.5 flex items-center justify-center ${
|
||||
allowed ? "bg-blue-500" : "bg-background-300"
|
||||
allowed ? "bg-blue-500" : "bg-gray-300"
|
||||
}`}
|
||||
>
|
||||
{formStep === index && (
|
||||
@@ -129,7 +129,7 @@ export default function Sidebar() {
|
||||
</div>
|
||||
<div
|
||||
className={`${
|
||||
index <= formStep ? "text-text-800" : "text-text-500"
|
||||
index <= formStep ? "text-gray-800" : "text-gray-500"
|
||||
}`}
|
||||
>
|
||||
{step}
|
||||
|
||||
@@ -16,7 +16,7 @@ export default function NumberInput({
|
||||
}) {
|
||||
return (
|
||||
<div className="w-full flex flex-col">
|
||||
<label className="block text-base font-medium text-text-700 dark:text-neutral-100 mb-1">
|
||||
<label className="block text-base font-medium text-text-700 mb-1">
|
||||
{label}
|
||||
{optional && <span className="text-text-500 ml-1">(optional)</span>}
|
||||
</label>
|
||||
@@ -27,10 +27,10 @@ export default function NumberInput({
|
||||
name={name}
|
||||
min="-1"
|
||||
className={`mt-2 block w-full px-3 py-2
|
||||
bg-[#fff] dark:bg-transparent border border-background-300 rounded-md
|
||||
text-sm shadow-sm placeholder-text-400
|
||||
bg-white border border-gray-300 rounded-md
|
||||
text-sm shadow-sm placeholder-gray-400
|
||||
focus:outline-none focus:border-sky-500 focus:ring-1 focus:ring-sky-500
|
||||
disabled:bg-background-50 disabled:text-text-500 disabled:border-background-200 disabled:shadow-none
|
||||
disabled:bg-gray-50 disabled:text-gray-500 disabled:border-gray-200 disabled:shadow-none
|
||||
invalid:border-pink-500 invalid:text-pink-600
|
||||
focus:invalid:border-pink-500 focus:invalid:ring-pink-500`}
|
||||
/>
|
||||
|
||||
@@ -30,10 +30,10 @@ export default function NumberInput({
|
||||
min="-1"
|
||||
value={value === 0 && showNeverIfZero ? "Never" : value}
|
||||
className={`mt-2 block w-full px-3 py-2
|
||||
bg-white border border-background-300 rounded-md
|
||||
text-sm shadow-sm placeholder-text-400
|
||||
bg-white border border-gray-300 rounded-md
|
||||
text-sm shadow-sm placeholder-gray-400
|
||||
focus:outline-none focus:border-sky-500 focus:ring-1 focus:ring-sky-500
|
||||
disabled:bg-background-50 disabled:text-text-500 disabled:border-background-200 disabled:shadow-none
|
||||
disabled:bg-gray-50 disabled:text-gray-500 disabled:border-gray-200 disabled:shadow-none
|
||||
invalid:border-pink-500 invalid:text-pink-600
|
||||
focus:invalid:border-pink-500 focus:invalid:ring-pink-500`}
|
||||
/>
|
||||
|
||||
@@ -34,9 +34,9 @@ export const DriveJsonUpload = ({
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 " +
|
||||
"cursor-pointer bg-backgrournd dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
"mr-3 text-sm text-gray-900 border border-gray-300 " +
|
||||
"cursor-pointer bg-backgrournd dark:text-gray-400 focus:outline-none " +
|
||||
"dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
|
||||
@@ -34,9 +34,9 @@ const DriveJsonUpload = ({
|
||||
<>
|
||||
<input
|
||||
className={
|
||||
"mr-3 text-sm text-text-900 border border-background-300 overflow-visible " +
|
||||
"cursor-pointer bg-background dark:text-text-400 focus:outline-none " +
|
||||
"dark:bg-background-700 dark:border-background-600 dark:placeholder-text-400"
|
||||
"mr-3 text-sm text-gray-900 border border-gray-300 overflow-visible " +
|
||||
"cursor-pointer bg-background dark:text-gray-400 focus:outline-none " +
|
||||
"dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400"
|
||||
}
|
||||
type="file"
|
||||
accept=".json"
|
||||
|
||||
@@ -50,7 +50,7 @@ const DocumentDisplay = ({
|
||||
</a>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-x-2 mt-1 text-xs">
|
||||
<div className="px-1 py-0.5 bg-accent-background-hovered rounded flex">
|
||||
<div className="px-1 py-0.5 bg-hover rounded flex">
|
||||
<p className="mr-1 my-auto">Boost:</p>
|
||||
<ScoreSection
|
||||
documentId={document.document_id}
|
||||
@@ -77,7 +77,7 @@ const DocumentDisplay = ({
|
||||
});
|
||||
}
|
||||
}}
|
||||
className="px-1 py-0.5 bg-accent-background-hovered hover:bg-accent-background rounded flex cursor-pointer select-none"
|
||||
className="px-1 py-0.5 bg-hover hover:bg-hover-light rounded flex cursor-pointer select-none"
|
||||
>
|
||||
<div className="my-auto">
|
||||
{document.hidden ? (
|
||||
@@ -169,7 +169,7 @@ export function Explorer({
|
||||
<div>
|
||||
{popup}
|
||||
<div className="justify-center py-2">
|
||||
<div className="flex items-center w-full border-2 border-border rounded-lg px-4 py-2 focus-within:border-accent bg-background-search dark:bg-transparent">
|
||||
<div className="flex items-center w-full border-2 border-border rounded-lg px-4 py-2 focus-within:border-accent bg-background-search">
|
||||
<MagnifyingGlass />
|
||||
<textarea
|
||||
autoFocus
|
||||
@@ -221,7 +221,7 @@ export function Explorer({
|
||||
</div>
|
||||
)}
|
||||
{!query && (
|
||||
<div className="flex text-text-darker mt-3">
|
||||
<div className="flex text-emphasis mt-3">
|
||||
Search for a document above to modify its boost or hide it from
|
||||
searches.
|
||||
</div>
|
||||
|
||||
@@ -37,7 +37,7 @@ const IsVisibleSection = ({
|
||||
);
|
||||
onUpdate(response);
|
||||
}}
|
||||
className="flex text-error cursor-pointer hover:bg-accent-background-hovered py-1 px-2 w-fit rounded-full"
|
||||
className="flex text-error cursor-pointer hover:bg-hover py-1 px-2 w-fit rounded-full"
|
||||
>
|
||||
<div className="select-none">Hidden</div>
|
||||
<div className="ml-1 my-auto">
|
||||
@@ -53,7 +53,7 @@ const IsVisibleSection = ({
|
||||
);
|
||||
onUpdate(response);
|
||||
}}
|
||||
className="flex cursor-pointer hover:bg-accent-background-hovered py-1 px-2 w-fit rounded-full"
|
||||
className="flex cursor-pointer hover:bg-hover py-1 px-2 w-fit rounded-full"
|
||||
>
|
||||
<div className="my-auto select-none">Visible</div>
|
||||
<div className="ml-1 my-auto">
|
||||
|
||||
Some files were not shown because too many files have changed in this diff