Compare commits


1 Commit

Author: pablodanswer
SHA1: 852c0d7a8c
Message: stashing
Date: 2025-02-06 16:58:14 -08:00
54 changed files with 407 additions and 1659 deletions

View File

@@ -94,19 +94,16 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
@@ -122,10 +119,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
@@ -133,17 +126,17 @@ jobs:
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
run: |
@@ -223,30 +216,27 @@ jobs:
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -5,6 +5,7 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -19,8 +20,13 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -57,6 +63,7 @@ def upgrade() -> None:
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -64,12 +71,15 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -160,9 +170,10 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
pass
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -179,6 +190,8 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -260,7 +273,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
pass
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")

View File

@@ -64,7 +64,6 @@ async def _get_tenant_id_from_request(
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:

View File

@@ -24,7 +24,6 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -86,8 +85,7 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)

View File

@@ -9,6 +9,7 @@ class CoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
@@ -17,4 +18,4 @@ class SubgraphCoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add] = []
log_messages: Annotated[list[str], add]
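
For context, a minimal sketch (not part of this PR) of how the reducer-annotated field above behaves, assuming `add` is `operator.add` as in the usual LangGraph state convention: updates to `log_messages` from different nodes are concatenated rather than overwritten.

from operator import add
from typing import Annotated
from pydantic import BaseModel

class ExampleCoreState(BaseModel):
    # With the operator.add reducer, LangGraph merges list updates from
    # successive (or parallel) nodes by concatenation instead of replacement.
    base_question: str = ""
    log_messages: Annotated[list[str], add] = []

# Conceptually, node updates ["step 1"] and ["step 2"] combine into
# ["step 1", "step 2"] in the shared state.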

View File

@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
@@ -12,39 +12,12 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
def check_sub_answer(
@@ -80,46 +53,14 @@ def check_sub_answer(
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
response = list(
fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - check sub answer")
if agent_error:
answer_quality = True
log_result = agent_error.error_result
else:
if response:
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
else:
answer_quality = True
quality_str = "yes - because LLM error"
log_result = f"Answer quality: {quality_str}"
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
@@ -128,7 +69,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=log_result,
result=f"Answer quality: {quality_str}",
)
],
)
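
A minimal, hypothetical sketch of the stream-then-merge pattern the new code relies on: `merge_message_runs` from langchain_core joins consecutive message chunks, and with `chunk_separator=""` the streamed pieces concatenate into a single message whose `.content` holds the full text.

from langchain_core.messages import AIMessageChunk, merge_message_runs

# Stand-in for the chunks collected via list(fast_llm.stream(...)).
chunks = [AIMessageChunk(content="ye"), AIMessageChunk(content="s")]

# Consecutive runs of the same message type are merged; an empty
# chunk_separator concatenates their string contents directly.
merged = merge_message_runs(chunks, chunk_separator="")[0]
print(merged.content)  # "yes"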

View File

@@ -16,20 +16,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -44,20 +30,11 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
def generate_sub_answer(
state: AnswerQuestionState,
@@ -80,8 +57,6 @@ def generate_sub_answer(
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
@@ -104,67 +79,41 @@ def generate_sub_answer(
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
logger.error("LLM Rate Limit Error - generate sub answer")
response.append(content)
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -182,7 +131,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result=log_results or "",
result="",
)
],
)

View File

@@ -42,8 +42,10 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str
question_id: str
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
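
To make the `<level>_<question_num>` convention above concrete, a hypothetical sketch of what `make_question_id` / `parse_question_id` (imported elsewhere in this diff from shared_graph_utils.utils) could look like; the real implementations may differ.

# Hypothetical helpers illustrating the "<level>_<question_num>" id scheme;
# "0_0" refers to the original question.
def make_question_id(level: int, question_num: int) -> str:
    return f"{level}_{question_num}"

def parse_question_id(question_id: str) -> tuple[int, int]:
    level_str, num_str = question_id.split("_")
    return int(level_str), int(num_str)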

View File

@@ -26,18 +26,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
@@ -53,16 +42,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.context.search.models import InferenceSection
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
@@ -72,12 +57,6 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
def generate_initial_answer(
state: SubQuestionRetrievalState,
@@ -245,82 +224,30 @@ def generate_initial_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
start_stream_token = datetime.now()
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return InitialAnswerUpdate(
initial_answer=None,
error=AgentErrorLoggingFormat(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"

View File

@@ -25,7 +25,7 @@ def validate_initial_answer(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True # not actually required as already streamed out. Refinement will do similar
verdict = True
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,

View File

@@ -23,18 +23,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -45,11 +33,6 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
@@ -60,12 +43,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
def decompose_orig_question(
state: SubQuestionRetrievalState,
@@ -135,35 +112,11 @@ def decompose_orig_question(
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
except LLMTimeoutError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -172,19 +125,19 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
if agent_error:
initial_sub_questions: list[str] = []
log_result = agent_error.error_result
else:
decomposition_response = merge_content(*streamed_tokens)
decomposition_response = merge_content(*streamed_tokens)
list_of_subqs = cast(str, decomposition_response).split("\n")
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, decomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
return InitialQuestionDecompositionUpdate(
initial_sub_questions=initial_sub_questions,
initial_sub_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
@@ -198,7 +151,7 @@ def decompose_orig_question(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=log_result,
result=f"decomposed original question into {len(decomp_list)} subquestions",
)
],
)

View File

@@ -252,7 +252,9 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(log_messages=[])
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,

View File

@@ -1,7 +1,6 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -11,37 +10,14 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
def compare_answers(
@@ -64,46 +40,15 @@ def compare_answers(
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLoggingFormat | None = None
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
resp = model.invoke(msg)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
log_result = f"Answer comparison: {refined_answer_improvement}"
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
write_custom_event(
"refined_answer_improvement",
@@ -120,7 +65,7 @@ def compare_answers(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=log_result,
result=f"Answer comparison: {refined_answer_improvement}",
)
],
)

View File

@@ -21,18 +21,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
@@ -42,25 +30,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
def create_refined_sub_questions(
@@ -123,65 +96,29 @@ def create_refined_sub_questions(
# Grader
model = graph_config.tooling.fast_llm
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
response = merge_content(*streamed_tokens)
raise ValueError("LLM response is not a string")
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
@@ -191,7 +128,7 @@ def create_refined_sub_questions(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=log_result,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
)
],
)

View File

@@ -26,19 +26,6 @@ def decide_refinement_need(
decision = True # TODO: just for current testing purposes
if state.error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -21,9 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
@@ -84,7 +81,6 @@ def extract_entities_terms(
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -11,6 +11,7 @@ from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
@@ -22,18 +23,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
@@ -53,14 +43,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
@@ -72,15 +56,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
def generate_refined_answer(
@@ -256,80 +231,28 @@ def generate_refined_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -343,6 +266,49 @@ def generate_refined_answer(
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (

View File

@@ -17,7 +17,6 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -77,7 +76,6 @@ class InitialAnswerUpdate(LoggerUpdate):
"""
initial_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
@@ -90,7 +88,6 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
refined_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False

View File

@@ -16,40 +16,14 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)
def expand_queries(
@@ -80,43 +54,13 @@ def expand_queries(
)
]
agent_error: AgentErrorLoggingFormat | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
try:
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
# use subquestion as query if query generation fails
if agent_error:
llm_response = ""
rewritten_queries = [question]
log_result = agent_error.error_result
else:
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
rewritten_queries = llm_response.split("\n")
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
@@ -125,7 +69,7 @@ def expand_queries(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=log_result,
result=f"Number of expanded queries: {len(rewritten_queries)}",
)
],
)

View File

@@ -1,6 +1,5 @@
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
@@ -11,41 +10,12 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)
def verify_documents(
@@ -56,7 +26,7 @@ def verify_documents(
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing AgentSearchConfig
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
verified_documents: list[InferenceSection]
@@ -81,42 +51,11 @@ def verify_documents(
)
]
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
response = fast_llm.invoke(msg)
try:
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - verify documents")
if agent_error or response is None:
verified_documents = [retrieved_document_to_verify]
else:
verified_documents = []
if isinstance(response.content, str) and binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents.append(retrieved_document_to_verify)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,

View File

@@ -21,13 +21,9 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalInput(SubgraphCoreState):
# exception from 'no default value' for LangGraph input states
# Here, sub_question_id default None implies usage for the
# original question. This is sometimes needed for nested sub-graphs
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
question: str
base_search: bool
## Update/Return States
@@ -92,4 +88,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str
query_to_retrieve: str = ""

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,7 +21,6 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -34,7 +33,6 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -74,15 +72,13 @@ def _parse_agent_event(
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput,
graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -96,7 +92,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput,
input: BasicInput | MainInput_a,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -127,7 +123,9 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput(log_messages=[])
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@@ -174,7 +172,9 @@ if __name__ == "__main__":
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput(log_messages=[])
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

View File

@@ -150,17 +150,3 @@ def get_prompt_enrichment_components(
history=history,
date_str=date_str,
)
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()

View File

@@ -1,17 +0,0 @@
from enum import Enum
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"

View File

@@ -1,5 +1,3 @@
from typing import Any
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
@@ -58,12 +56,6 @@ class InitialAgentResultStats(BaseModel):
agent_effectiveness: dict[str, float | int | None]
class AgentErrorLoggingFormat(BaseModel):
error_message: str
error_type: str
error_result: str | None = None
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
@@ -134,12 +126,3 @@ class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str
class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"
BaseMessage_Content = str | list[str | dict[str, Any]]

View File

@@ -20,7 +20,6 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -35,9 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -50,8 +46,6 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -71,9 +65,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
BaseMessage_Content = str | list[str | dict[str, Any]]
# Post-processing
@@ -379,24 +372,8 @@ def summarize_history(
)
)
try:
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
history_response = llm.invoke(history_context_prompt)
assert isinstance(history_response.content, str)
return history_response.content
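
The removed try/except implemented a fallback pattern: cap the summarization call with a per-step timeout and return the unsummarized history if the LLM times out or is rate limited. A minimal sketch of that shape, with placeholder exception types standing in for LLMTimeoutError and LLMRateLimitError:

    # sketch of the fallback being removed above; exception types are placeholders
    def summarize_with_fallback(llm, prompt, history: str, timeout_s: int = 10) -> str:
        try:
            response = llm.invoke(prompt, timeout_override=timeout_s)
        except (TimeoutError, ConnectionError):  # stand-ins for timeout / rate-limit errors
            return history  # same default behavior as before: keep the raw history
        assert isinstance(response.content, str)
        return response.content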

View File

@@ -13,21 +13,6 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
#####
# Agent Configs
#####
@@ -92,76 +77,4 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 4
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 10
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
GRAPH_VERSION_NAME: str = "a"
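
All of these overrides resolve the same way: int(os.environ.get(NAME) or DEFAULT) falls back to the default when the variable is unset or set to an empty string, and any non-empty value must parse as an integer. A tiny standalone illustration of that rule (the variable name is made up):

    import os

    AGENT_DEFAULT_TIMEOUT = 25  # seconds

    # unset or empty env var -> 25; AGENT_TIMEOUT_OVERRIDE="30" -> 30
    timeout = int(os.environ.get("AGENT_TIMEOUT_OVERRIDE") or AGENT_DEFAULT_TIMEOUT)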

View File

@@ -263,11 +263,6 @@ class PostgresAdvisoryLocks(Enum):
class OnyxCeleryQueues:
# "celery" is the default queue defined by celery and also the queue
# we are running in the primary worker to run system tasks
# Tasks running in this queue should be designed specifically to run quickly
PRIMARY = "celery"
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"

View File

@@ -91,7 +91,6 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
f"&response_type=code"
f"&scope=read"
f"&state={state}"
f"&prompt=consent" # prompts user for access; allows choosing workspace
)
@classmethod

View File

@@ -50,18 +50,6 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
class LLMTimeoutError(Exception):
"""
Exception raised when an LLM call times out.
"""
class LLMRateLimitError(Exception):
"""
Exception raised when an LLM call is rate limited.
"""
def _base_msg_to_role(msg: BaseMessage) -> str:
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
return "user"
@@ -392,7 +380,6 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -418,7 +405,7 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -437,12 +424,6 @@ class DefaultMultiLLM(LLM):
except Exception as e:
self._record_error(processed_prompt, e)
# for break pointing
if isinstance(e, litellm.Timeout):
raise LLMTimeoutError(e)
elif isinstance(e, litellm.RateLimitError):
raise LLMRateLimitError(e)
raise e
@property
@@ -463,7 +444,6 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -471,12 +451,7 @@ class DefaultMultiLLM(LLM):
response = cast(
litellm.ModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
@@ -494,31 +469,19 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
prompt, tools, tool_choice, True, structured_response_format
),
)
try:

View File

@@ -81,7 +81,6 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -91,6 +90,5 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -90,13 +90,12 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
@@ -106,7 +105,6 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -116,13 +114,12 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt, tools, tool_choice, structured_response_format
)
tokens = []
@@ -141,6 +138,5 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
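
The base class keeps invoke and stream as thin wrappers: run the shared _precall hook, then delegate to the abstract *_implementation method that each concrete LLM provides. A compact, simplified sketch of that structure (signatures reduced to a single prompt argument, not the real interface):

    import abc

    class MiniLLM(abc.ABC):
        def invoke(self, prompt: str) -> str:
            self._precall(prompt)  # shared validation / logging hook
            return self._invoke_implementation(prompt)

        def _precall(self, prompt: str) -> None:
            if not prompt:
                raise ValueError("empty prompt")

        @abc.abstractmethod
        def _invoke_implementation(self, prompt: str) -> str:
            raise NotImplementedError

    class EchoLLM(MiniLLM):
        def _invoke_implementation(self, prompt: str) -> str:
            return prompt  # trivial stand-in for a real model call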

View File

@@ -5,6 +5,8 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
YES = "yes"
NO = "no"
# Framing/Support/Template Prompts
HISTORY_FRAMING_PROMPT = f"""
For more context, here is the history of the conversation so far that preceded this question:

View File

@@ -22,8 +22,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import (
get_connector_credential_pair_from_id_for_user,
@@ -230,13 +228,6 @@ def update_cc_pair_status(
db_session.commit()
# this speeds up the start of indexing by firing the check immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(tenant_id=tenant_id),
priority=OnyxCeleryPriority.HIGH,
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
@@ -549,14 +540,7 @@ def associate_credential_to_connector(
metadata: ConnectorCredentialPairMetadata,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
and create_connector_with_mock_credential should be combined.
The intent of this endpoint is to handle connectors that actually need credentials.
"""
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
@@ -579,18 +563,6 @@ def associate_credential_to_connector(
groups=metadata.groups,
)
# trigger indexing immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"associate_credential_to_connector - running check_for_indexing: "
f"cc_pair={response.data}"
)
return response
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")

View File

@@ -804,14 +804,6 @@ def create_connector_with_mock_credential(
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
and associate_credential_to_connector should be combined.
The intent of this endpoint is to handle connectors that don't need credentials,
AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
serve this purpose.
"""
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
@@ -849,18 +841,6 @@ def create_connector_with_mock_credential(
groups=connector_data.groups,
)
# trigger indexing immediately
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"create_connector_with_mock_credential - running check_for_indexing: "
f"cc_pair={response.data}"
)
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
@@ -1025,8 +1005,6 @@ def connector_run_once(
kwargs={"tenant_id": tenant_id},
)
logger.info("connector_run_once - running check_for_indexing")
msg = f"Marked {num_triggers} index attempts with indexing triggers."
return StatusResponse(
success=True,

View File

@@ -179,10 +179,12 @@ def oauth_callback(
db_session=db_session,
)
# TODO: use a library for url handling
sep = "&" if "?" in desired_return_url else "?"
return CallbackResponse(
redirect_url=f"{desired_return_url}{sep}credentialId={credential.id}"
redirect_url=(
f"{desired_return_url}?credentialId={credential.id}"
if "?" not in desired_return_url
else f"{desired_return_url}&credentialId={credential.id}"
)
)
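
Both versions of the handler append credentialId by hand, using & when the return URL already carries a query string and ? otherwise, which is what the TODO about a URL library refers to. A small sketch of the same behavior using the standard library, purely illustrative rather than what the endpoint actually uses:

    from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

    def append_credential_id(url: str, credential_id: int) -> str:
        parts = urlparse(url)
        query = dict(parse_qsl(parts.query))
        query["credentialId"] = str(credential_id)
        return urlunparse(parts._replace(query=urlencode(query)))

    # append_credential_id("https://app.example.com/admin/oauth?tab=linear", 42)
    # -> "https://app.example.com/admin/oauth?tab=linear&credentialId=42"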

View File

@@ -6,15 +6,11 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import insert_document_set
from onyx.db.document_set import mark_document_set_as_to_be_deleted
from onyx.db.document_set import update_document_set
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.features.document_set.models import CheckDocSetPublicRequest
@@ -33,7 +29,6 @@ def create_document_set(
document_set_creation_request: DocumentSetCreationRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> int:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
@@ -51,13 +46,6 @@ def create_document_set(
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
return document_set_db_model.id
@@ -66,7 +54,6 @@ def patch_document_set(
document_set_update_request: DocumentSetUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
@@ -85,19 +72,12 @@ def patch_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
@router.delete("/admin/document-set/{document_set_id}")
def delete_document_set(
document_set_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
try:
mark_document_set_as_to_be_deleted(
@@ -108,12 +88,6 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
"""Endpoints for non-admins"""

View File

@@ -197,11 +197,6 @@ def create_deletion_attempt_for_connector_id(
kwargs={"tenant_id": tenant_id},
)
logger.info(
f"create_deletion_attempt_for_connector_id - running check_for_connector_deletion: "
f"cc_pair={cc_pair.id}"
)
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store(db_session)

View File

@@ -34,7 +34,6 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -287,7 +286,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None

View File

@@ -70,8 +70,8 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./onyx /app/onyx
COPY ./shared_configs /app/shared_configs
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf

View File

@@ -24,6 +24,35 @@ def generate_auth_token() -> str:
class TenantManager:
@staticmethod
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}
token = generate_auth_token()
headers = {
"Authorization": f"Bearer {token}",
"X-API-KEY": "",
"Content-Type": "application/json",
}
response = requests.post(
url=f"{API_SERVER_URL}/tenants/create",
json=body,
headers=headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all_users(
user_performing_action: DATestUser | None = None,

View File

@@ -92,7 +92,6 @@ class UserManager:
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}
return test_user
@staticmethod
@@ -103,7 +102,6 @@ class UserManager:
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_to_verify.headers,
cookies=user_to_verify.cookies,
)
if user_to_verify.is_active is False:

View File

@@ -242,18 +242,6 @@ def reset_postgres_multitenant() -> None:
schema_name = schema[0]
cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
# Drop tables in the public schema
cur.execute(
"""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
"""
)
public_tables = cur.fetchall()
for table in public_tables:
table_name = table[0]
cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')
cur.close()
conn.close()

View File

@@ -44,7 +44,6 @@ class DATestUser(BaseModel):
headers: dict
role: UserRole
is_active: bool
cookies: dict = {}
class DATestPersonaLabel(BaseModel):

View File

@@ -4,6 +4,7 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
@@ -12,28 +13,25 @@ from tests.integration.common_utils.test_models import DATestUser
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Creating an admin user (first user created is automatically an admin and also provisions the tenant)
admin_user1: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
assert UserManager.is_role(test_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
admin_user2: DATestUser = UserManager.create(
email="admin2@onyx-test.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
assert UserManager.is_role(test_user2, UserRole.ADMIN)
# Create connectors for Tenant 1
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user1,
user_performing_action=test_user1,
)
api_key_1: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user1,
user_performing_action=test_user1,
)
api_key_1.headers.update(admin_user1.headers)
LLMProviderManager.create(user_performing_action=admin_user1)
api_key_1.headers.update(test_user1.headers)
LLMProviderManager.create(user_performing_action=test_user1)
# Seed documents for Tenant 1
cc_pair_1.documents = []
@@ -51,13 +49,13 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user2,
user_performing_action=test_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user2,
user_performing_action=test_user2,
)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)
api_key_2.headers.update(test_user2.headers)
LLMProviderManager.create(user_performing_action=test_user2)
# Seed documents for Tenant 2
cc_pair_2.documents = []
@@ -78,17 +76,17 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create chat sessions for each user
chat_session1: DATestChatSession = ChatSessionManager.create(
user_performing_action=admin_user1
user_performing_action=test_user1
)
chat_session2: DATestChatSession = ChatSessionManager.create(
user_performing_action=admin_user2
user_performing_action=test_user2
)
# User 1 sends a message and gets a response
response1 = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 1's documents?",
user_performing_action=admin_user1,
user_performing_action=test_user1,
)
# Assert that the search tool was used
assert response1.tool_name == "run_search"
@@ -102,16 +100,14 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
), "Tenant 2 document IDs should not be in the response"
# Assert that the contents are correct
assert any(
doc["content"] == "Tenant 1 Document Content"
for doc in response1.tool_result or []
), "Tenant 1 Document Content not found in any document"
for doc in response1.tool_result or []:
assert doc["content"] == "Tenant 1 Document Content"
# User 2 sends a message and gets a response
response2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 2's documents?",
user_performing_action=admin_user2,
user_performing_action=test_user2,
)
# Assert that the search tool was used
assert response2.tool_name == "run_search"
@@ -123,18 +119,15 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
assert not response_doc_ids.intersection(
tenant1_doc_ids
), "Tenant 1 document IDs should not be in the response"
# Assert that the contents are correct
assert any(
doc["content"] == "Tenant 2 Document Content"
for doc in response2.tool_result or []
), "Tenant 2 Document Content not found in any document"
for doc in response2.tool_result or []:
assert doc["content"] == "Tenant 2 Document Content"
# User 1 tries to access Tenant 2's documents
response_cross = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 2's documents?",
user_performing_action=admin_user1,
user_performing_action=test_user1,
)
# Assert that the search tool was used
assert response_cross.tool_name == "run_search"
@@ -147,7 +140,7 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
response_cross2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 1's documents?",
user_performing_action=admin_user2,
user_performing_action=test_user2,
)
# Assert that the search tool was used
assert response_cross2.tool_name == "run_search"

View File

@@ -4,12 +4,14 @@ from onyx.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)

View File

@@ -1,23 +1,23 @@
import time
from datetime import datetime
from onyx.db.models import IndexingStatus
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
def _verify_index_attempt_pagination(
cc_pair_id: int,
index_attempt_ids: list[int],
index_attempts: list[DATestIndexAttempt],
page_size: int = 5,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_attempts: list[int] = []
last_time_started = None # Track the last time_started seen
for i in range(0, len(index_attempt_ids), page_size):
for i in range(0, len(index_attempts), page_size):
paginated_result = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair_id,
page=(i // page_size),
@@ -26,9 +26,9 @@ def _verify_index_attempt_pagination(
)
# Verify that the total items is equal to the length of the index attempts list
assert paginated_result.total_items == len(index_attempt_ids)
assert paginated_result.total_items == len(index_attempts)
# Verify that the number of items in the page is equal to the page size
assert len(paginated_result.items) == min(page_size, len(index_attempt_ids) - i)
assert len(paginated_result.items) == min(page_size, len(index_attempts) - i)
# Verify time ordering within the page (descending order)
for attempt in paginated_result.items:
@@ -42,7 +42,7 @@ def _verify_index_attempt_pagination(
retrieved_attempts.extend([attempt.id for attempt in paginated_result.items])
# Create a set of all the expected index attempt IDs
all_expected_attempts = set(index_attempt_ids)
all_expected_attempts = set(attempt.id for attempt in index_attempts)
# Create a set of all the retrieved index attempt IDs
all_retrieved_attempts = set(retrieved_attempts)
@@ -51,9 +51,6 @@ def _verify_index_attempt_pagination(
def test_index_attempt_pagination(reset: None) -> None:
MAX_WAIT = 60
all_attempt_ids: list[int] = []
# Create an admin user to perform actions
user_performing_action: DATestUser = UserManager.create(
name="admin_performing_action",
@@ -65,49 +62,20 @@ def test_index_attempt_pagination(reset: None) -> None:
user_performing_action=user_performing_action,
)
# Creating a CC pair will create an index attempt as well. wait for it.
start = time.monotonic()
while True:
paginated_result = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair.id,
page=0,
page_size=5,
user_performing_action=user_performing_action,
)
if paginated_result.total_items == 1:
all_attempt_ids.append(paginated_result.items[0].id)
print("Initial index attempt from cc_pair creation detected. Continuing...")
break
elapsed = time.monotonic() - start
if elapsed > MAX_WAIT:
raise TimeoutError(
f"Initial index attempt: Not detected within {MAX_WAIT} seconds."
)
print(
f"Waiting for initial index attempt: elapsed={elapsed:.2f} timeout={MAX_WAIT}"
)
time.sleep(1)
# Create 299 successful index attempts (for 300 total)
# Create 300 successful index attempts
base_time = datetime.now()
generated_attempts = IndexAttemptManager.create_test_index_attempts(
num_attempts=299,
all_attempts = IndexAttemptManager.create_test_index_attempts(
num_attempts=300,
cc_pair_id=cc_pair.id,
status=IndexingStatus.SUCCESS,
base_time=base_time,
)
for attempt in generated_attempts:
all_attempt_ids.append(attempt.id)
# Verify basic pagination with different page sizes
print("Verifying basic pagination with page size 5")
_verify_index_attempt_pagination(
cc_pair_id=cc_pair.id,
index_attempt_ids=all_attempt_ids,
index_attempts=all_attempts,
page_size=5,
user_performing_action=user_performing_action,
)
@@ -116,7 +84,7 @@ def test_index_attempt_pagination(reset: None) -> None:
print("Verifying pagination with page size 100")
_verify_index_attempt_pagination(
cc_pair_id=cc_pair.id,
index_attempt_ids=all_attempt_ids,
index_attempts=all_attempts,
page_size=100,
user_performing_action=user_performing_action,
)
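
The length assertion in the helper above is the standard last-page arithmetic: page i holds min(page_size, total - i * page_size) items, so every page is full except possibly the last. A tiny standalone check of that rule (the 7-item page size is only for illustration; the test uses 5 and 100):

    # size of 0-indexed page i when `total` items are split into pages of `page_size`
    def page_len(i: int, total: int, page_size: int) -> int:
        return min(page_size, total - i * page_size)

    assert page_len(59, 300, 5) == 5      # last page with page_size=5 is full
    assert page_len(2, 300, 100) == 100   # last page with page_size=100 is full
    assert page_len(42, 300, 7) == 6      # a size that doesn't divide 300 evenly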

View File

@@ -1,423 +0,0 @@
services:
api_server:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile
command: >
/bin/sh -c "
alembic -n schema_private upgrade head &&
echo \"Starting Onyx Api Server\" &&
uvicorn onyx.main:app --host 0.0.0.0 --port 8080"
depends_on:
- relational_db
- index
- cache
- inference_model_server
restart: always
ports:
- "8080:8080"
environment:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
- MULTI_TENANT=true
- LOG_LEVEL=DEBUG
- AUTH_TYPE=cloud
- REQUIRE_EMAIL_VERIFICATION=false
- DISABLE_TELEMETRY=true
- IMAGE_TAG=test
- DEV_MODE=true
# Auth Settings
- SESSION_EXPIRE_TIME_SECONDS=${SESSION_EXPIRE_TIME_SECONDS:-}
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
- VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- SMTP_SERVER=${SMTP_SERVER:-}
- SMTP_PORT=${SMTP_PORT:-587}
- SMTP_USER=${SMTP_USER:-}
- SMTP_PASS=${SMTP_PASS:-}
- ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}
- EMAIL_FROM=${EMAIL_FROM:-}
- OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
- DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- BING_API_KEY=${BING_API_KEY:-}
- DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
- LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
# Other services
- POSTGRES_HOST=relational_db
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-}
# Don't change the NLP model configs unless you know what you're doing
- EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-}
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-}
- LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
- LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
- LOG_ENDPOINT_LATENCY=${LOG_ENDPOINT_LATENCY:-}
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Linear OAuth Configs
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
# Chat Configs
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
# Enables the use of bedrock models or IAM Auth
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
# Seeding configuration
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
background:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile
command: >
/bin/sh -c "
if [ -f /etc/ssl/certs/custom-ca.crt ]; then
update-ca-certificates;
fi &&
/usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf"
depends_on:
- relational_db
- index
- cache
- inference_model_server
- indexing_model_server
restart: always
environment:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
- MULTI_TENANT=true
- LOG_LEVEL=DEBUG
- AUTH_TYPE=cloud
- REQUIRE_EMAIL_VERIFICATION=false
- DISABLE_TELEMETRY=true
- IMAGE_TAG=test
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
- JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-}
# Gen AI Settings (Needed by OnyxBot)
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
- DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-}
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- BING_API_KEY=${BING_API_KEY:-}
# Query Options
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
- LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
# Other Services
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-}
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
# Indexing Configs
- VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-}
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
- ENABLED_CONNECTOR_TYPES=${ENABLED_CONNECTOR_TYPES:-}
- DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
- DASK_JOB_CLIENT_ENABLED=${DASK_JOB_CLIENT_ENABLED:-}
- CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
- EXPERIMENTAL_CHECKPOINTING_ENABLED=${EXPERIMENTAL_CHECKPOINTING_ENABLED:-}
- CONFLUENCE_CONNECTOR_LABELS_TO_SKIP=${CONFLUENCE_CONNECTOR_LABELS_TO_SKIP:-}
- JIRA_CONNECTOR_LABELS_TO_SKIP=${JIRA_CONNECTOR_LABELS_TO_SKIP:-}
- WEB_CONNECTOR_VALIDATE_URLS=${WEB_CONNECTOR_VALIDATE_URLS:-}
- JIRA_API_VERSION=${JIRA_API_VERSION:-}
- GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
- MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
- MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Linear OAuth Configs
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_CONCURRENCY=${CELERY_WORKER_LIGHT_CONCURRENCY:-}
- CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-}
# Onyx SlackBot Configs
- DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
- DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
- DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
- DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-}
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
# Logging
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.onyx.app/more/telemetry
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
- LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging
# Log all of Onyx prompts and interactions with the LLM
- LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
- LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
# Enterprise Edition stuff
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# Uncomment the following lines if you need to include a custom CA certificate
# This section enables the use of a custom CA certificate
# If present, the custom CA certificate is mounted as a volume
# The container checks for its existence and updates the system's CA certificates
# This allows for secure communication with services using custom SSL certificates
# Optional volume mount for CA certificate
# volumes:
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile
# - ${CA_CERT_PATH:-./custom-ca.crt}:/etc/ssl/certs/custom-ca.crt:ro
web_server:
image: onyxdotapp/onyx-web-server:${IMAGE_TAG:-latest}
build:
context: ../../web
dockerfile: Dockerfile
args:
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
- NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA=${NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA:-false}
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
- NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
# DO NOT TURN ON unless you have EXPLICIT PERMISSION from Onyx.
- NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED:-false}
depends_on:
- api_server
restart: always
environment:
- INTERNAL_URL=http://api_server:8080
- WEB_DOMAIN=${WEB_DOMAIN:-}
- THEME_IS_DARK=${THEME_IS_DARK:-}
- DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
# Enterprise Edition only
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-}
inference_model_server:
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
volumes:
# Not necessary, this is just to reduce download time during startup
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
- CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
relational_db:
image: postgres:15.2-alpine
command: -c 'max_connections=250'
restart: always
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
volumes:
- db_volume:/var/lib/postgresql/data
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
restart: always
ports:
- "19071:19071"
- "8081:8081"
volumes:
- vespa_volume:/opt/vespa/var
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
nginx:
image: nginx:1.23.4-alpine
restart: always
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
# if api_server / web_server are not up
depends_on:
- api_server
- web_server
environment:
- DOMAIN=localhost
ports:
- "80:80"
- "3000:80" # allow for localhost:3000 usage, since that is the norm
volumes:
- ../data/nginx:/etc/nginx/conf.d
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# The specified script waits for the api_server to start up.
# Without this we've seen issues where nginx shows no error logs but
# does not receive any traffic
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
# in order to make this work on both Unix-like systems and windows
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
cache:
image: redis:7.4-alpine
restart: always
ports:
- "6379:6379"
# docker silently mounts /data even without an explicit volume mount, which enables
# persistence. explicitly setting save and appendonly forces ephemeral behavior.
command: redis-server --save "" --appendonly no
volumes:
db_volume:
vespa_volume: # Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:

View File

@@ -18,11 +18,7 @@ import AdvancedFormPage from "./pages/Advanced";
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
import CreateCredential from "@/components/credentials/actions/CreateCredential";
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
import {
ConfigurableSources,
oauthSupportedSources,
ValidSources,
} from "@/lib/types";
import { ConfigurableSources, oauthSupportedSources } from "@/lib/types";
import {
Credential,
credentialTemplates,
@@ -448,7 +444,7 @@ export default function AddConnector({
<CardSection>
<Title className="mb-2 text-lg">Select a credential</Title>
{connector == ValidSources.Gmail ? (
{connector == "gmail" ? (
<GmailMain />
) : (
<>

View File

@@ -949,15 +949,15 @@ export function ChatPage({
// Check if all messages are currently rendered
if (currentVisibleRange.end < messageHistory.length) {
// Update visible range to include the last messages
updateCurrentVisibleRange({
start: Math.max(
0,
messageHistory.length -
(currentVisibleRange.end - currentVisibleRange.start)
),
end: messageHistory.length,
mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
});
// updateCurrentVisibleRange({
// start: Math.max(
// 0,
// messageHistory.length -
// (currentVisibleRange.end - currentVisibleRange.start)
// ),
// end: messageHistory.length,
// mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
// });
// Wait for the state update and re-render before scrolling
setTimeout(() => {
@@ -1121,7 +1121,6 @@ export function ChatPage({
"Continue Generating (pick up exactly where you left off)",
});
};
const [uncaughtError, setUncaughtError] = useState<string | null>(null);
const onSubmit = async ({
messageIdToResend,
@@ -1550,23 +1549,8 @@ export function ChatPage({
}
);
} else if (Object.hasOwn(packet, "error")) {
if (
sub_questions.length > 0 &&
sub_questions
.filter((q) => q.level === 0)
.every((q) => q.is_stopped === true)
) {
setUncaughtError((packet as StreamingError).error);
updateChatState("input");
setAgenticGenerating(false);
setAlternativeGeneratingAssistant(null);
setSubmittedMessage("");
return;
// throw new Error((packet as StreamingError).error);
} else {
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
}
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
} else if (Object.hasOwn(packet, "stop_reason")) {
@@ -1885,6 +1869,7 @@ export function ChatPage({
newRange: VisibleRange,
forceUpdate?: boolean
) => {
console.log("updateCurrentVisibleRange", newRange);
if (
scrollInitialized.current &&
visibleRange.get(loadedIdSessionRef.current) == undefined &&
@@ -1923,26 +1908,54 @@ export function ChatPage({
scrollInitialized.current = true;
}
};
const setVisibleRangeForCurrentSessionId = (newRange: VisibleRange) => {
console.log("setVisibleRangeForCurrentSessionId", newRange);
setVisibleRange((prevState) => {
const newState = new Map(prevState);
newState.set(currentSessionId(), newRange);
return newState;
});
};
const updateVisibleRangeBasedOnScroll = () => {
if (!scrollInitialized.current) return;
function updateVisibleRangeBasedOnScroll() {
const scrollableDiv = scrollableDivRef.current;
if (!scrollableDiv) return;
const viewportHeight = scrollableDiv.clientHeight;
let mostVisibleMessageIndex = -1;
const distanceFromBottom =
scrollableDiv.scrollHeight -
scrollableDiv.scrollTop -
scrollableDiv.clientHeight;
const isNearBottom = distanceFromBottom < 200;
// If user is near bottom, we treat the last message as "most visible"
if (isNearBottom) {
const startIndex = Math.max(0, messageHistory.length - BUFFER_COUNT);
const endIndex = messageHistory.length;
setVisibleRangeForCurrentSessionId({
start: startIndex,
end: endIndex,
mostVisibleMessageId:
messageHistory.length > 0
? messageHistory[messageHistory.length - 1].messageId
: null,
});
return;
}
// otherwise do the bounding rect logic:
let mostVisibleMessageIndex = -1;
const viewportHeight = scrollableDiv.clientHeight;
messageHistory.forEach((message, index) => {
const messageElement = document.getElementById(
`message-${message.messageId}`
);
if (messageElement) {
const rect = messageElement.getBoundingClientRect();
const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
const elem = document.getElementById(`message-${message.messageId}`);
if (elem) {
const rect = elem.getBoundingClientRect();
const isVisible = rect.top < viewportHeight && rect.bottom >= 0;
if (isVisible && index > mostVisibleMessageIndex) {
mostVisibleMessageIndex = index;
}
}
// clientScrollToBottom;
});
if (mostVisibleMessageIndex !== -1) {
@@ -1951,34 +1964,50 @@ export function ChatPage({
messageHistory.length,
mostVisibleMessageIndex + BUFFER_COUNT + 1
);
updateCurrentVisibleRange({
setVisibleRangeForCurrentSessionId({
start: startIndex,
end: endIndex,
mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
});
}
};
}
useEffect(() => {
initializeVisibleRange();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [router, messageHistory]);
useLayoutEffect(() => {
const scrollableDiv = scrollableDivRef.current;
useEffect(() => {
console.log("useEffect has been called");
const scrollEl = scrollableDivRef.current;
if (!scrollEl) {
console.log("no scrollEl");
return;
}
const handleScroll = () => {
updateVisibleRangeBasedOnScroll();
requestAnimationFrame(() => {
updateVisibleRangeBasedOnScroll();
});
};
scrollableDiv?.addEventListener("scroll", handleScroll);
const attachScrollListener = () => {
if (scrollEl) {
scrollEl.addEventListener("scroll", handleScroll);
} else {
console.log("scrollEl not available, retrying in 100ms");
setTimeout(attachScrollListener, 100);
}
};
attachScrollListener();
return () => {
scrollableDiv?.removeEventListener("scroll", handleScroll);
if (scrollEl) {
scrollEl.removeEventListener("scroll", handleScroll);
}
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [messageHistory]);
}, [scrollableDivRef, messageHistory, currentSessionId()]);
const imageFileInMessageHistory = useMemo(() => {
return messageHistory
@@ -2055,7 +2084,6 @@ export function ChatPage({
}
const data = await response.json();
router.push(data.redirect_url);
} catch (error) {
console.error("Error seeding chat from Slack:", error);
@@ -2510,7 +2538,7 @@ export function ChatPage({
? messageHistory
: messageHistory.slice(
currentVisibleRange.start,
currentVisibleRange.end
currentVisibleRange.end + 2
)
).map((message, fauxIndex) => {
const i =
@@ -2650,7 +2678,6 @@ export function ChatPage({
{message.sub_questions &&
message.sub_questions.length > 0 ? (
<AgenticMessage
error={uncaughtError}
docSidebarToggled={
documentSidebarToggled &&
(selectedMessageForDocDisplay ==

View File

@@ -80,7 +80,6 @@ export const AgenticMessage = ({
agenticDocs,
secondLevelSubquestions,
toggleDocDisplay,
error,
}: {
docSidebarToggled?: boolean;
isImprovement?: boolean | null;
@@ -111,7 +110,6 @@ export const AgenticMessage = ({
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
toggleDocDisplay?: (agentic: boolean) => void;
error?: string | null;
}) => {
const [noShowingMessage, setNoShowingMessage] = useState(isComplete);
@@ -485,28 +483,11 @@ export const AgenticMessage = ({
) : (
content
)}
{error && (
<p className="mt-2 text-red-700 text-sm my-auto">
{error}
</p>
)}
</div>
</div>
</>
) : isComplete ? (
error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
</p>
)
) : (
<>
{error && (
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
{error}
</p>
)}
</>
) : isComplete ? null : (
<></>
)}
{handleFeedback &&
(isActive ? (

View File

@@ -185,7 +185,6 @@ export const AIMessage = ({
setPresentingDocument,
index,
toggledDocumentSidebar,
removePadding,
}: {
index?: number;
shared?: boolean;
@@ -214,7 +213,6 @@ export const AIMessage = ({
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
removePadding?: boolean;
}) => {
const toolCallGenerating = toolCall && !toolCall.tool_result;
@@ -400,9 +398,7 @@ export const AIMessage = ({
<div
id={isComplete ? "onyx-ai-message" : undefined}
ref={trackedElementRef}
className={`py-5 ml-4 lg:px-5 relative flex
${removePadding && "!pl-24 -mt-12"}`}
className={`py-5 ml-4 lg:px-5 relative flex `}
>
<div
className={`mx-auto ${
@@ -411,13 +407,11 @@ export const AIMessage = ({
>
<div className={`lg:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
<div className="flex">
{!removePadding && (
<AssistantIcon
className="mobile:hidden"
size={24}
assistant={alternativeAssistant || currentPersona}
/>
)}
<AssistantIcon
className="mobile:hidden"
size={24}
assistant={alternativeAssistant || currentPersona}
/>
<div className="w-full">
<div className="max-w-message-max break-words">
@@ -594,8 +588,7 @@ export const AIMessage = ({
)}
</div>
{!removePadding &&
handleFeedback &&
{handleFeedback &&
(isActive ? (
<div
className={`

View File

@@ -772,7 +772,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
advanced_values: [],
},
linear: {
description: "Configure Linear connector",
description: "Configure Dropbox connector",
values: [],
advanced_values: [],
},

View File

@@ -114,7 +114,7 @@ export interface LoopioCredentialJson {
}
export interface LinearCredentialJson {
linear_access_token: string;
linear_api_key: string;
}
export interface HubSpotCredentialJson {
@@ -250,7 +250,7 @@ export const credentialTemplates: Record<ValidSources, any> = {
gong_access_key_secret: "",
} as GongCredentialJson,
zulip: { zuliprc_content: "" } as ZulipCredentialJson,
linear: { linear_access_token: "" } as LinearCredentialJson,
linear: { linear_api_key: "" } as LinearCredentialJson,
hubspot: { hubspot_access_token: "" } as HubSpotCredentialJson,
document360: {
portal_id: "",
@@ -404,7 +404,7 @@ export const credentialDisplayNames: Record<string, string> = {
loopio_client_token: "Loopio Client Token",
// Linear
linear_access_token: "Linear Access Token",
linear_api_key: "Linear API Key",
// HubSpot
hubspot_access_token: "HubSpot Access Token",