Mirror of https://github.com/onyx-dot-app/onyx.git, synced 2026-02-20 09:15:47 +00:00

Compare commits: agent-sear...virtualiza (1 commit: 852c0d7a8c)

.github/workflows/pr-integration-tests.yml (vendored, 56 lines changed)
@@ -94,19 +94,16 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant

# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \

@@ -122,10 +119,6 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true

@@ -133,17 +126,17 @@ jobs:

- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All multi-tenant integration tests passed successfully."
echo "All integration tests passed successfully."
fi

- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

- name: Start Docker containers
run: |

@@ -223,30 +216,27 @@ jobs:
echo "All integration tests passed successfully."
fi

# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log

- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log

- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
@@ -5,6 +5,7 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644

"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa

@@ -19,8 +20,13 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None

# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)


def upgrade() -> None:
    logger.info(f"{revision}: create_table: slack_bot")
    # Create new slack_bot table
    op.create_table(
        "slack_bot",

@@ -57,6 +63,7 @@ def upgrade() -> None:
    )

    # Handle existing Slack bot tokens first
    logger.info(f"{revision}: Checking for existing Slack bot.")
    bot_token = None
    app_token = None
    first_row_id = None

@@ -64,12 +71,15 @@ def upgrade() -> None:
    try:
        tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
    except Exception:
        logger.warning("No existing Slack bot tokens found.")
        tokens = {}

    bot_token = tokens.get("bot_token")
    app_token = tokens.get("app_token")

    if bot_token and app_token:
        logger.info(f"{revision}: Found bot and app tokens.")

        session = Session(bind=op.get_bind())
        new_slack_bot = SlackBot(
            name="Slack Bot (Migrated)",

@@ -160,9 +170,10 @@ def upgrade() -> None:
    # Clean up old tokens if they existed
    try:
        if bot_token and app_token:
            logger.info(f"{revision}: Removing old bot and app tokens.")
            get_kv_store().delete("slack_bot_tokens_config_key")
    except Exception:
        pass
        logger.warning("tried to delete tokens in dynamic config but failed")
    # Rename the table
    op.rename_table(
        "slack_bot_config__standard_answer_category",

@@ -179,6 +190,8 @@ def upgrade() -> None:
    # Drop the table with CASCADE to handle dependent objects
    op.execute("DROP TABLE slack_bot_config CASCADE")

    logger.info(f"{revision}: Migration complete.")


def downgrade() -> None:
    # Recreate the old slack_bot_config table

@@ -260,7 +273,7 @@ def downgrade() -> None:
        }
        get_kv_store().store("slack_bot_tokens_config_key", tokens)
    except Exception:
        pass
        logger.warning("Failed to save tokens back to KV store")

    # Drop the new tables in reverse order
    op.drop_table("slack_channel_config")
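
Condensed for reference, the token-migration flow that upgrade() implements across the hunks above, as a minimal sketch. get_kv_store and SlackBot come from the migration's own imports (not shown in these hunks), and the bot_token/app_token keyword arguments on SlackBot are assumptions; the diff only shows the name= argument.

from typing import cast

from alembic import op
from sqlalchemy.orm import Session

def migrate_slack_tokens() -> None:
    # Read the old KV-store tokens, create a SlackBot row from them,
    # then delete the KV key so the tokens live only in the new table.
    try:
        tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
    except Exception:
        tokens = {}  # nothing to migrate

    bot_token = tokens.get("bot_token")
    app_token = tokens.get("app_token")
    if bot_token and app_token:
        session = Session(bind=op.get_bind())
        session.add(
            SlackBot(
                name="Slack Bot (Migrated)",
                bot_token=bot_token,  # assumed field name
                app_token=app_token,  # assumed field name
            )
        )
        session.commit()
        get_kv_store().delete("slack_bot_tokens_config_key")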
@@ -64,7 +64,6 @@ async def _get_tenant_id_from_request(

    try:
        # Look up token data in Redis
        token_data = await retrieve_auth_token_data_from_redis(request)

        if not token_data:
@@ -24,7 +24,6 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine

@@ -86,8 +85,7 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
        # Provision tenant on data plane
        await provision_tenant(tenant_id, email)
        # Notify control plane
        if not DEV_MODE:
            await notify_control_plane(tenant_id, email, referral_source)
        await notify_control_plane(tenant_id, email, referral_source)
    except Exception as e:
        logger.error(f"Tenant provisioning failed: {e}")
        await rollback_tenant_provisioning(tenant_id)
@@ -9,6 +9,7 @@ class CoreState(BaseModel):
    This is the core state that is shared across all subgraphs.
    """

    base_question: str = ""
    log_messages: Annotated[list[str], add] = []

@@ -17,4 +18,4 @@ class SubgraphCoreState(BaseModel):
    This is the core state that is shared across all subgraphs.
    """

    log_messages: Annotated[list[str], add] = []
    log_messages: Annotated[list[str], add]
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (

@@ -12,39 +12,12 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
    SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
    rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
    general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)


def check_sub_answer(

@@ -80,46 +53,14 @@ def check_sub_answer(

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    fast_llm = graph_config.tooling.fast_llm
    agent_error: AgentErrorLoggingFormat | None = None
    response: BaseMessage | None = None
    try:
        response = fast_llm.invoke(
        response = list(
            fast_llm.stream(
                prompt=msg,
                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
            )
        )

    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - check sub answer")

    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - check sub answer")

    if agent_error:
        answer_quality = True
        log_result = agent_error.error_result

    else:
        if response:
            quality_str: str = cast(str, response.content)
            answer_quality = binary_string_test(
                text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
            )

        else:
            answer_quality = True
            quality_str = "yes - because LLM error"

        log_result = f"Answer quality: {quality_str}"
    quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
    answer_quality = "yes" in quality_str.lower()

    return SubQuestionAnswerCheckUpdate(
        answer_quality=answer_quality,

@@ -128,7 +69,7 @@ def check_sub_answer(
                graph_component="initial - generate individual sub answer",
                node_name="check sub answer",
                node_start_time=node_start_time,
                result=log_result,
                result=f"Answer quality: {quality_str}",
            )
        ],
    )
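
For reference, the stream-then-merge relevance check that the shorter side of this hunk converges on, as a self-contained sketch. The helper name and prompt wiring are illustrative, not from the diff; the real node pulls fast_llm from GraphConfig and uses SUB_ANSWER_CHECK_PROMPT.

from langchain_core.messages import HumanMessage, merge_message_runs

def sub_answer_is_relevant(llm, prompt_text: str) -> bool:
    # Drain the token stream, merge the chunks into a single message, and
    # apply a binary "yes" test on the merged content.
    chunks = list(llm.stream(prompt=[HumanMessage(content=prompt_text)]))
    if not chunks:
        return True  # the node treats a missing response as 'relevant'
    quality_str = merge_message_runs(chunks, chunk_separator="")[0].content
    return "yes" in str(quality_str).lower()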
@@ -16,20 +16,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,

@@ -44,20 +30,11 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
    rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
    general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)


def generate_sub_answer(
    state: AnswerQuestionState,

@@ -80,8 +57,6 @@ def generate_sub_answer(

    if len(context_docs) == 0:
        answer_str = NO_RECOVERED_DOCS
        cited_documents: list = []
        log_results = "No documents retrieved"
        write_custom_event(
            "sub_answers",
            AgentAnswerPiece(

@@ -104,67 +79,41 @@ def generate_sub_answer(

    response: list[str | list[str | dict[str, Any]]] = []
    dispatch_timings: list[float] = []

    agent_error: AgentErrorLoggingFormat | None = None

    try:
        for message in fast_llm.stream(
            prompt=msg,
            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()
            write_custom_event(
                "sub_answers",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=level,
                    level_question_num=question_num,
                    answer_type="agent_sub_answer",
                ),
                writer,
    for message in fast_llm.stream(
        prompt=msg,
    ):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            response.append(content)

    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        start_stream_token = datetime.now()
        write_custom_event(
            "sub_answers",
            AgentAnswerPiece(
                answer_piece=content,
                level=level,
                level_question_num=question_num,
                answer_type="agent_sub_answer",
            ),
            writer,
        )
        logger.error("LLM Timeout Error - generate sub answer")
    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        end_stream_token = datetime.now()
        dispatch_timings.append(
            (end_stream_token - start_stream_token).microseconds
        )
        logger.error("LLM Rate Limit Error - generate sub answer")
        response.append(content)

    if agent_error:
        answer_str = LLM_ANSWER_ERROR_MESSAGE
        cited_documents = []
        log_results = (
            agent_error.error_result
            or "Sub-answer generation failed due to LLM error"
        )
    answer_str = merge_message_runs(response, chunk_separator="")[0].content
    logger.debug(
        f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
    )

    else:
        answer_str = merge_message_runs(response, chunk_separator="")[0].content
        answer_citation_ids = get_answer_citation_ids(answer_str)
        cited_documents = [
            context_docs[id] for id in answer_citation_ids if id < len(context_docs)
        ]
        log_results = None
    answer_citation_ids = get_answer_citation_ids(answer_str)
    cited_documents = [
        context_docs[id] for id in answer_citation_ids if id < len(context_docs)
    ]

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,

@@ -182,7 +131,7 @@ def generate_sub_answer(
                graph_component="initial - generate individual sub answer",
                node_name="generate sub answer",
                node_start_time=node_start_time,
                result=log_results or "",
                result="",
            )
        ],
    )
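
The stream-and-dispatch loop above reduces to the following self-contained sketch; the function name and parameter plumbing are assumed, while write_custom_event, AgentAnswerPiece, and the event payload fields come from the hunk itself.

from datetime import datetime

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece

def stream_sub_answer(fast_llm, msg, writer, level: int, question_num: int) -> str:
    # Forward each streamed token to the frontend as an AgentAnswerPiece and
    # record the per-token dispatch latency in microseconds.
    response: list[str] = []
    dispatch_timings: list[float] = []
    for message in fast_llm.stream(prompt=msg):
        content = message.content
        if not isinstance(content, str):
            raise ValueError(f"Expected content to be a string, but got {type(content)}")
        start_stream_token = datetime.now()
        write_custom_event(
            "sub_answers",
            AgentAnswerPiece(
                answer_piece=content,
                level=level,
                level_question_num=question_num,
                answer_type="agent_sub_answer",
            ),
            writer,
        )
        dispatch_timings.append((datetime.now() - start_stream_token).microseconds)
        response.append(content)
    # the node logs sum(dispatch_timings) / len(dispatch_timings) as the
    # average dispatch time before merging the chunks into the answer string
    return "".join(response)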
@@ -42,8 +42,10 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):


class SubQuestionAnsweringInput(SubgraphCoreState):
    question: str
    question_id: str
    question: str = ""
    question_id: str = (
        ""  # 0_0 is original question, everything else is <level>_<question_num>.
    )
    # level 0 is original question and first decomposition, level 1 is follow up, etc
    # question_num is a unique number per original question per level.
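
This diff imports parse_question_id and calls make_question_id(1, n + 1) elsewhere, but never shows their bodies. An assumed reading of the id scheme described in the comments above, as a sketch:

def make_question_id(level: int, question_num: int) -> str:
    # e.g. level 1, question 2 -> "1_2"; "0_0" is the original question
    return f"{level}_{question_num}"

def parse_question_id(question_id: str) -> tuple[int, int]:
    # inverse of make_question_id; assumes the "<level>_<question_num>" format
    level, question_num = question_id.split("_")
    return int(level), int(question_num)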
@@ -26,18 +26,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,
)

@@ -53,16 +42,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.context.search.models import InferenceSection
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
    INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)

@@ -72,12 +57,6 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. The initial answer could not be generated.",
    rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
    general_error="General LLM Error. The initial answer could not be generated.",
)


def generate_initial_answer(
    state: SubQuestionRetrievalState,

@@ -245,82 +224,30 @@ def generate_initial_answer(

    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
    dispatch_timings: list[float] = []

    agent_error: AgentErrorLoggingFormat | None = None

    try:
        for message in model.stream(
            msg,
            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()

            write_custom_event(
                "initial_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
    for message in model.stream(msg):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            streamed_tokens.append(content)
        start_stream_token = datetime.now()

    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - generate initial answer")

    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - generate initial answer")

    if agent_error:
        write_custom_event(
            "initial_agent_answer",
            StreamingError(
                error=AGENT_LLM_TIMEOUT_MESSAGE,
            AgentAnswerPiece(
                answer_piece=content,
                level=0,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )
        return InitialAnswerUpdate(
            initial_answer=None,
            error=AgentErrorLoggingFormat(
                error_message=agent_error.error_message or "An LLM error occurred",
                error_type=agent_error.error_type,
                error_result=agent_error.error_result,
            ),
            initial_agent_stats=None,
            generated_sub_questions=sub_questions,
            agent_base_end_time=None,
            agent_base_metrics=None,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="initial - generate initial answer",
                    node_name="generate initial answer",
                    node_start_time=node_start_time,
                    result=agent_error.error_result or "An LLM error occurred",
                )
            ],
        end_stream_token = datetime.now()
        dispatch_timings.append(
            (end_stream_token - start_stream_token).microseconds
        )
        streamed_tokens.append(content)

    logger.debug(
        f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -25,7 +25,7 @@ def validate_initial_answer(
        f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
    )

    verdict = True  # not actually required as already streamed out. Refinement will do similar
    verdict = True

    return InitialAnswerQualityUpdate(
        initial_answer_quality_eval=verdict,
@@ -23,18 +23,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,

@@ -45,11 +33,6 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)

@@ -60,12 +43,6 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="LLM Timeout Error. Sub-questions could not be generated.",
    rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
    general_error="General LLM Error. Sub-questions could not be generated.",
)


def decompose_orig_question(
    state: SubQuestionRetrievalState,

@@ -135,35 +112,11 @@ def decompose_orig_question(
    )

    # dispatches custom events for subquestion tokens, adding in subquestion ids.

    agent_error: AgentErrorLoggingFormat | None = None
    streamed_tokens: list[BaseMessage_Content] = []

    try:
        streamed_tokens = dispatch_separated(
            model.stream(
                msg,
                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
            ),
            dispatch_subquestion(0, writer),
            sep_callback=dispatch_subquestion_sep(0, writer),
        )
    except LLMTimeoutError as e:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - decompose orig question")
        raise e  # fail loudly on this critical step
    except LLMRateLimitError as e:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - decompose orig question")
        raise e
    streamed_tokens = dispatch_separated(
        model.stream(msg),
        dispatch_subquestion(0, writer),
        sep_callback=dispatch_subquestion_sep(0, writer),
    )

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,

@@ -172,19 +125,19 @@ def decompose_orig_question(
    )
    write_custom_event("stream_finished", stop_event, writer)

    if agent_error:
        initial_sub_questions: list[str] = []
        log_result = agent_error.error_result
    else:
        deomposition_response = merge_content(*streamed_tokens)
    deomposition_response = merge_content(*streamed_tokens)

        list_of_subqs = cast(str, deomposition_response).split("\n")
    # this call should only return strings. Commenting out for efficiency
    # assert [type(tok) == str for tok in streamed_tokens]

        initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
        log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
    # use no-op cast() instead of str() which runs code
    # list_of_subquestions = clean_and_parse_list_string(cast(str, response))
    list_of_subqs = cast(str, deomposition_response).split("\n")

    decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]

    return InitialQuestionDecompositionUpdate(
        initial_sub_questions=initial_sub_questions,
        initial_sub_questions=decomp_list,
        agent_start_time=agent_start_time,
        agent_refined_start_time=None,
        agent_refined_end_time=None,

@@ -198,7 +151,7 @@ def decompose_orig_question(
                graph_component="initial - generate sub answers",
                node_name="decompose original question",
                node_start_time=node_start_time,
                result=log_result,
                result=f"decomposed original question into {len(decomp_list)} subquestions",
            )
        ],
    )
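
The post-stream parsing both sides of this hunk share boils down to a merge-and-split step. A minimal sketch, assuming the streamed tokens are plain strings (the node merges them via merge_content and casts before splitting):

def parse_sub_questions(streamed_tokens: list[str]) -> list[str]:
    # Join the streamed chunks and split on newlines; blank lines are dropped,
    # leaving one sub-question per remaining line.
    response = "".join(streamed_tokens)
    return [sq.strip() for sq in response.split("\n") if sq.strip() != ""]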
@@ -252,7 +252,9 @@ if __name__ == "__main__":
        db_session, primary_llm, fast_llm, search_request
    )

    inputs = MainInput(log_messages=[])
    inputs = MainInput(
        base_question=graph_config.inputs.search_request.query, log_messages=[]
    )

    for thing in compiled_graph.stream(
        input=inputs,
@@ -1,7 +1,6 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

@@ -11,37 +10,14 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="The LLM timed out, and the answers could not be compared.",
    rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
    general_error="The LLM encountered an error, and the answers could not be compared.",
)


def compare_answers(

@@ -64,46 +40,15 @@ def compare_answers(

    msg = [HumanMessage(content=compare_answers_prompt)]

    agent_error: AgentErrorLoggingFormat | None = None
    # Get the rewritten queries in a defined format
    model = graph_config.tooling.fast_llm
    resp: BaseMessage | None = None
    refined_answer_improvement: bool | None = None

    # no need to stream this
    try:
        resp = model.invoke(
            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
        )
    resp = model.invoke(msg)

    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - compare answers")
        # continue as True in this support step
    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - compare answers")
        # continue as True in this support step

    if agent_error or resp is None:
        refined_answer_improvement = True
        if agent_error:
            log_result = agent_error.error_result
        else:
            log_result = "An answer could not be generated."

    else:
        refined_answer_improvement = (
            isinstance(resp.content, str) and "yes" in resp.content.lower()
        )
        log_result = f"Answer comparison: {refined_answer_improvement}"
    refined_answer_improvement = (
        isinstance(resp.content, str) and "yes" in resp.content.lower()
    )

    write_custom_event(
        "refined_answer_improvement",

@@ -120,7 +65,7 @@ def compare_answers(
                graph_component="main",
                node_name="compare answers",
                node_start_time=node_start_time,
                result=log_result,
                result=f"Answer comparison: {refined_answer_improvement}",
            )
        ],
    )
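
The comparison verdict both sides of this hunk compute is a single non-streamed binary check. A self-contained sketch, with the helper name and prompt argument assumed (the node formats INITIAL_REFINED_ANSWER_COMPARISON_PROMPT with both answers):

from langchain_core.messages import HumanMessage

def refined_answer_improved(llm, comparison_prompt: str) -> bool:
    # One invoke() call; any non-string content or a missing "yes" counts as
    # "no improvement". On LLM errors the removed branch defaulted to True.
    resp = llm.invoke([HumanMessage(content=comparison_prompt)])
    return isinstance(resp.content, str) and "yes" in resp.content.lower()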
@@ -21,18 +21,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
    format_entity_term_extraction,

@@ -42,25 +30,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="The LLM timed out. The sub-questions could not be generated.",
    rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
    general_error="The LLM encountered an error. The sub-questions could not be generated.",
)


def create_refined_sub_questions(

@@ -123,65 +96,29 @@ def create_refined_sub_questions(
    # Grader
    model = graph_config.tooling.fast_llm

    agent_error: AgentErrorLoggingFormat | None = None
    streamed_tokens: list[BaseMessage_Content] = []
    try:
        streamed_tokens = dispatch_separated(
            model.stream(
                msg,
                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
            ),
            dispatch_subquestion(1, writer),
            sep_callback=dispatch_subquestion_sep(1, writer),
        )
    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - create refined sub questions")

    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - create refined sub questions")

    if agent_error:
        refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
        log_result = agent_error.error_result
        write_custom_event(
            "refined_sub_question_creation_error",
            StreamingError(
                error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
            ),
            writer,
        )
    streamed_tokens = dispatch_separated(
        model.stream(msg),
        dispatch_subquestion(1, writer),
        sep_callback=dispatch_subquestion_sep(1, writer),
    )
    response = merge_content(*streamed_tokens)

    if isinstance(response, str):
        parsed_response = [q for q in response.split("\n") if q.strip() != ""]
    else:
        response = merge_content(*streamed_tokens)
        raise ValueError("LLM response is not a string")

        if isinstance(response, str):
            parsed_response = [q for q in response.split("\n") if q.strip() != ""]
        else:
            raise ValueError("LLM response is not a string")
    refined_sub_question_dict = {}
    for sub_question_num, sub_question in enumerate(parsed_response):
        refined_sub_question = RefinementSubQuestion(
            sub_question=sub_question,
            sub_question_id=make_question_id(1, sub_question_num + 1),
            verified=False,
            answered=False,
            answer="",
        )

        refined_sub_question_dict = {}
        for sub_question_num, sub_question in enumerate(parsed_response):
            refined_sub_question = RefinementSubQuestion(
                sub_question=sub_question,
                sub_question_id=make_question_id(1, sub_question_num + 1),
                verified=False,
                answered=False,
                answer="",
            )

            refined_sub_question_dict[sub_question_num + 1] = refined_sub_question

        log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
        refined_sub_question_dict[sub_question_num + 1] = refined_sub_question

    return RefinedQuestionDecompositionUpdate(
        refined_sub_questions=refined_sub_question_dict,

@@ -191,7 +128,7 @@ def create_refined_sub_questions(
                graph_component="main",
                node_name="create refined sub questions",
                node_start_time=node_start_time,
                result=log_result,
                result=f"Created {len(refined_sub_question_dict)} refined sub questions",
            )
        ],
    )
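
dispatch_separated is called throughout this diff but never defined in it. One plausible reading of its contract, as a sketch only; the real implementation in shared_graph_utils.utils may split and number tokens differently:

def dispatch_separated(token_stream, dispatch, sep_callback=None, sep="\n"):
    # Assumed contract: forward each streamed token to `dispatch`, fire
    # `sep_callback` whenever a separator appears in a token, and return all
    # tokens so the caller can merge_content() them afterwards.
    tokens = []
    for message in token_stream:
        content = message.content
        if sep_callback is not None and isinstance(content, str) and sep in content:
            sep_callback()
        dispatch(content)
        tokens.append(content)
    return tokens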
@@ -26,19 +26,6 @@ def decide_refinement_need(

    decision = True  # TODO: just for current testing purposes

    if state.error:
        return RequireRefinemenEvalUpdate(
            require_refined_answer_eval=False,
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="main",
                    node_name="decide refinement need",
                    node_start_time=node_start_time,
                    result="Timeout Error",
                )
            ],
        )

    log_messages = [
        get_langgraph_node_log_string(
            graph_component="main",
@@ -21,9 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE

@@ -84,7 +81,6 @@ def extract_entities_terms(
    # Grader
    llm_response = fast_llm.invoke(
        prompt=msg,
        timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
    )

    cleaned_response = (
@@ -11,6 +11,7 @@ from onyx.agents.agent_search.deep_search.main.models import (
    AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
    RefinedAnswerUpdate,

@@ -22,18 +23,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
    dedup_inference_sections,

@@ -53,14 +43,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)

@@ -72,15 +56,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="The LLM timed out. The refined answer could not be generated.",
    rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
    general_error="The LLM encountered an error. The refined answer could not be generated.",
)


def generate_refined_answer(

@@ -256,80 +231,28 @@ def generate_refined_answer(

    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
    dispatch_timings: list[float] = []
    agent_error: AgentErrorLoggingFormat | None = None

    try:
        for message in model.stream(
            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )

            start_stream_token = datetime.now()
            write_custom_event(
                "refined_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=1,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
    for message in model.stream(msg):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            streamed_tokens.append(content)

    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - generate refined answer")

    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - generate refined answer")

    if agent_error:
        start_stream_token = datetime.now()
        write_custom_event(
            "initial_agent_answer",
            StreamingError(
                error=AGENT_LLM_TIMEOUT_MESSAGE,
            "refined_agent_answer",
            AgentAnswerPiece(
                answer_piece=content,
                level=1,
                level_question_num=0,
                answer_type="agent_level_answer",
            ),
            writer,
        )

        return RefinedAnswerUpdate(
            refined_answer=None,
            refined_answer_quality=False,  # TODO: replace this with the actual check value
            refined_agent_stats=None,
            agent_refined_end_time=None,
            agent_refined_metrics=AgentRefinedMetrics(
                refined_doc_boost_factor=0.0,
                refined_question_boost_factor=0.0,
                duration_s=None,
            ),
            log_messages=[
                get_langgraph_node_log_string(
                    graph_component="main",
                    node_name="generate refined answer",
                    node_start_time=node_start_time,
                    result=agent_error.error_result or "An LLM error occurred",
                )
            ],
        )
        end_stream_token = datetime.now()
        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
        streamed_tokens.append(content)

    logger.debug(
        f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"

@@ -343,6 +266,49 @@ def generate_refined_answer(
        revision_question_efficiency=revision_question_efficiency,
    )

    logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
    logger.debug("-" * 10)
    logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")

    logger.debug("-" * 100)

    if state.initial_agent_stats:
        initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
            "utilized_chunk_ratio", "--"
        )
        initial_support_boost_factor = (
            state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
        )
        num_initial_verified_docs = state.initial_agent_stats.original_question.get(
            "num_verified_documents", "--"
        )
        initial_verified_docs_avg_score = (
            state.initial_agent_stats.original_question.get("verified_avg_score", "--")
        )
        initial_sub_questions_verified_docs = (
            state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
        )

        logger.debug("INITIAL AGENT STATS")
        logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
        logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
        logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
        logger.debug(
            f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
        )
        logger.debug(
            f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
        )
    if refined_agent_stats:
        logger.debug("-" * 10)
        logger.debug("REFINED AGENT STATS")
        logger.debug(
            f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
        )
        logger.debug(
            f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
        )

    agent_refined_end_time = datetime.now()
    if state.agent_refined_start_time:
        agent_refined_duration = (
@@ -17,7 +17,6 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)

@@ -77,7 +76,6 @@ class InitialAnswerUpdate(LoggerUpdate):
    """

    initial_answer: str | None = None
    error: AgentErrorLoggingFormat | None = None
    initial_agent_stats: InitialAgentResultStats | None = None
    generated_sub_questions: list[str] = []
    agent_base_end_time: datetime | None = None

@@ -90,7 +88,6 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
    """

    refined_answer: str | None = None
    error: AgentErrorLoggingFormat | None = None
    refined_agent_stats: RefinedAgentStats | None = None
    refined_answer_quality: bool = False
@@ -16,40 +16,14 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
    QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
    rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
    general_error="Query rewriting failed due to LLM error - the original question will be used.",
)


def expand_queries(
@@ -80,43 +54,13 @@ def expand_queries(
        )
    ]

    agent_error: AgentErrorLoggingFormat | None = None
    llm_response_list: list[BaseMessage_Content] = []
    llm_response_list = dispatch_separated(
        llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
    )

    try:
        llm_response_list = dispatch_separated(
            llm.stream(
                prompt=msg,
                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
            ),
            dispatch_subquery(level, question_num, writer),
        )
    except LLMTimeoutError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - expand queries")
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    except LLMRateLimitError:
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - expand queries")
    # use subquestion as query if query generation fails
    if agent_error:
        llm_response = ""
        rewritten_queries = [question]
        log_result = agent_error.error_result
    else:
        llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
            0
        ].content
        rewritten_queries = llm_response.split("\n")
        log_result = f"Number of expanded queries: {len(rewritten_queries)}"
    rewritten_queries = llm_response.split("\n")

    return QueryExpansionUpdate(
        expanded_queries=rewritten_queries,
@@ -125,7 +69,7 @@ def expand_queries(
            graph_component="shared - expanded retrieval",
            node_name="expand queries",
            node_start_time=node_start_time,
            result=log_result,
            result=f"Number of expanded queries: {len(rewritten_queries)}",
        )
    ],
)
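One side of this hunk wraps the llm.stream call in a try/except: on LLMTimeoutError or LLMRateLimitError the node logs, records an AgentErrorLoggingFormat, and falls back to the sub-question itself. A minimal sketch of that control flow, with call_llm and the two exception classes as hypothetical stand-ins for the Onyx-specific pieces:

class LLMCallTimeout(Exception):
    pass


class LLMCallRateLimited(Exception):
    pass


def expand_with_fallback(question: str, call_llm) -> list[str]:
    # Try to expand the question into several queries; degrade gracefully
    # to the original question on timeout or rate limit.
    try:
        raw = call_llm(question, timeout=4)  # mirrors the 4s rewriting override
    except (LLMCallTimeout, LLMCallRateLimited):
        return [question]  # use the sub-question itself as the only query
    return [q for q in raw.split("\n") if q.strip()]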
@@ -1,6 +1,5 @@
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig

@@ -11,41 +10,12 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
    DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
    DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
    timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
    rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
    general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)


def verify_documents(
@@ -56,7 +26,7 @@ def verify_documents(

    Args:
        state (DocVerificationInput): The current state
        config (RunnableConfig): Configuration containing AgentSearchConfig
        config (RunnableConfig): Configuration containing ProSearchConfig

    Updates:
        verified_documents: list[InferenceSection]
@@ -81,42 +51,11 @@ def verify_documents(
        )
    ]

    agent_error: AgentErrorLoggingFormat | None = None
    response: BaseMessage | None = None
    response = fast_llm.invoke(msg)

    try:
        response = fast_llm.invoke(
            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
        )

    except LLMTimeoutError:
        # In this case, we decide to continue and don't raise an error, as
        # there is little harm in letting some docs through that are less relevant.
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.TIMEOUT,
            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
            error_result=_llm_node_error_strings.timeout,
        )
        logger.error("LLM Timeout Error - verify documents")
    except LLMRateLimitError:
        # In this case, we decide to continue and don't raise an error, as
        # there is little harm in letting some docs through that are less relevant.
        agent_error = AgentErrorLoggingFormat(
            error_type=AgentLLMErrorType.RATE_LIMIT,
            error_message=AGENT_LLM_RATELIMIT_MESSAGE,
            error_result=_llm_node_error_strings.rate_limit,
        )
        logger.error("LLM Rate Limit Error - verify documents")

    if agent_error or response is None:
        verified_documents = [retrieved_document_to_verify]

    else:
        verified_documents = []
        if isinstance(response.content, str) and binary_string_test(
            text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
        ):
            verified_documents.append(retrieved_document_to_verify)
    verified_documents = []
    if isinstance(response.content, str) and "yes" in response.content.lower():
        verified_documents.append(retrieved_document_to_verify)

    return DocVerificationUpdate(
        verified_documents=verified_documents,
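Note that this verification node fails open: when the fast LLM cannot be consulted, the document is kept rather than silently dropped, since letting a marginal document through is cheaper than losing a relevant one. A condensed sketch of that shape (ask_llm is a hypothetical stand-in for fast_llm.invoke):

def verify_or_keep(doc, prompt: str, ask_llm) -> list:
    try:
        answer = ask_llm(prompt, timeout=3)  # mirrors the 3s verification override
    except Exception:
        return [doc]  # fail open: treat the document as relevant
    return [doc] if "yes" in answer.lower() else []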
@@ -21,13 +21,9 @@ from onyx.context.search.models import InferenceSection


class ExpandedRetrievalInput(SubgraphCoreState):
    # Exception from the 'no default value' rule for LangGraph input states.
    # Here, sub_question_id defaulting to None implies usage for the
    # original question. This is sometimes needed for nested sub-graphs.

    question: str = ""
    base_search: bool = False
    sub_question_id: str | None = None
    question: str
    base_search: bool


## Update/Return States
@@ -92,4 +88,4 @@ class DocVerificationInput(ExpandedRetrievalInput):


class RetrievalInput(ExpandedRetrievalInput):
    query_to_retrieve: str
    query_to_retrieve: str = ""
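Defaults on these input states matter because nested sub-graphs may construct the state without every field populated; with defaults, a bare instantiation is valid and sub_question_id=None signals "operating on the original question". A sketch of the effect, assuming a Pydantic-style state class (InputStateSketch is illustrative only):

from pydantic import BaseModel


class InputStateSketch(BaseModel):
    question: str = ""
    base_search: bool = False
    sub_question_id: str | None = None  # None => the original question


state = InputStateSketch()  # legal only because every field has a default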
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
    main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
    MainInput as MainInput,
    MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,7 +21,6 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -34,7 +33,6 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger


logger = setup_logger()

_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -74,15 +72,13 @@ def _parse_agent_event(
        return cast(AnswerPacket, event["data"])
    elif event["name"] == "refined_answer_improvement":
        return cast(RefinedAnswerImprovement, event["data"])
    elif event["name"] == "refined_sub_question_creation_error":
        return cast(StreamingError, event["data"])
    return None


def manage_sync_streaming(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    graph_input: BasicInput | MainInput,
    graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
    message_id = config.persistence.message_id if config.persistence else None
    for event in compiled_graph.stream(
@@ -96,7 +92,7 @@ def manage_sync_streaming(
def run_graph(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    input: BasicInput | MainInput,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    config.behavior.perform_initial_search_decomposition = (
        INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -127,7 +123,9 @@ def run_main_graph(
) -> AnswerStream:
    compiled_graph = load_compiled_graph()

    input = MainInput(log_messages=[])
    input = MainInput_a(
        base_question=config.inputs.search_request.query, log_messages=[]
    )

    # Agent search is not a Tool per se, but this is helpful for the frontend
    yield ToolCallKickoff(
@@ -174,7 +172,9 @@ if __name__ == "__main__":
    # search_request.persona = get_persona_by_id(1, None, db_session)
    # config.perform_initial_search_path_decision = False
    config.behavior.perform_initial_search_decomposition = True
    input = MainInput(log_messages=[])
    input = MainInput_a(
        base_question=config.inputs.search_request.query, log_messages=[]
    )

    tool_responses: list = []
    for output in run_graph(compiled_graph, config, input):
@@ -150,17 +150,3 @@ def get_prompt_enrichment_components(
        history=history,
        date_str=date_str,
    )


def binary_string_test(text: str, positive_value: str = "yes") -> bool:
    """
    Tests if a string contains a positive value (case-insensitive).

    Args:
        text: The string to test
        positive_value: The value to look for (defaults to "yes")

    Returns:
        True if the positive value is found in the text
    """
    return positive_value.lower() in text.lower()
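Worth noting: binary_string_test is a case-insensitive substring check, so prompt wording should steer the model toward clean yes/no tokens. For instance:

binary_string_test("Yes, the document is relevant.")   # True
binary_string_test("NO", positive_value="no")          # True (case-insensitive)
binary_string_test("eyes on the prize")                # also True: "yes" is a substring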
@@ -1,17 +0,0 @@
from enum import Enum

AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
    "The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."

AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"


class AgentLLMErrorType(str, Enum):
    TIMEOUT = "timeout"
    RATE_LIMIT = "rate_limit"
    GENERAL_ERROR = "general_error"
@@ -1,5 +1,3 @@
from typing import Any

from pydantic import BaseModel

from onyx.agents.agent_search.deep_search.main.models import (
@@ -58,12 +56,6 @@ class InitialAgentResultStats(BaseModel):
    agent_effectiveness: dict[str, float | int | None]


class AgentErrorLoggingFormat(BaseModel):
    error_message: str
    error_type: str
    error_result: str | None = None


class RefinedAgentStats(BaseModel):
    revision_doc_efficiency: float | None
    revision_question_efficiency: float | None
@@ -134,12 +126,3 @@ class AgentPromptEnrichmentComponents(BaseModel):
    persona_prompts: PersonaPromptExpressions
    history: str
    date_str: str


class LLMNodeErrorStrings(BaseModel):
    timeout: str = "LLM Timeout Error"
    rate_limit: str = "LLM Rate Limit Error"
    general_error: str = "General LLM Error"


BaseMessage_Content = str | list[str | dict[str, Any]]
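Because every LLMNodeErrorStrings field carries a default, each node overrides only the messages it cares about and the rest fall back to the generic strings:

_strings = LLMNodeErrorStrings(
    timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
)
_strings.timeout     # the node-specific message above
_strings.rate_limit  # falls back to the default "LLM Rate Limit Error"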
@@ -20,7 +20,6 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)
@@ -35,9 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
    AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -50,8 +46,6 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -71,9 +65,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

logger = setup_logger()
BaseMessage_Content = str | list[str | dict[str, Any]]


# Post-processing
@@ -379,24 +372,8 @@ def summarize_history(
        )
    )

    try:
        history_response = llm.invoke(
            history_context_prompt,
            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
        )
    except LLMTimeoutError:
        logger.error("LLM Timeout Error - summarize history")
        return (
            history  # this is what is done at this point anyway, so we default to this
        )
    except LLMRateLimitError:
        logger.error("LLM Rate Limit Error - summarize history")
        return (
            history  # this is what is done at this point anyway, so we default to this
        )

    history_response = llm.invoke(history_context_prompt)
    assert isinstance(history_response.content, str)

    return history_response.content
@@ -13,21 +13,6 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 3  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8  # in seconds

#####
# Agent Configs
#####
@@ -92,76 +77,4 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
    or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
)  # 2000


AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
)  # 25

AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)  # 3

AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
)  # 30

AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
)  # 8

AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
)  # 12

AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
)  # 25

AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
)  # 25

AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
)  # 8

AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
)  # 6

AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
)  # 4

AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
)  # 10

AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)  # 8


GRAPH_VERSION_NAME: str = "a"
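Each config above uses the int(os.environ.get(NAME) or DEFAULT) idiom: an unset or empty environment variable falls through to the default because both None and "" are falsy. A quick illustration (AGENT_TIMEOUT_DEMO is a made-up variable for this example):

import os

DEFAULT_TIMEOUT = 4
os.environ["AGENT_TIMEOUT_DEMO"] = ""  # empty string still falls back
assert int(os.environ.get("AGENT_TIMEOUT_DEMO") or DEFAULT_TIMEOUT) == 4
os.environ["AGENT_TIMEOUT_DEMO"] = "9"
assert int(os.environ.get("AGENT_TIMEOUT_DEMO") or DEFAULT_TIMEOUT) == 9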
@@ -263,11 +263,6 @@ class PostgresAdvisoryLocks(Enum):


class OnyxCeleryQueues:
    # "celery" is the default queue defined by celery and also the queue
    # we are running in the primary worker to run system tasks.
    # Tasks running in this queue should be designed specifically to run quickly.
    PRIMARY = "celery"

    # Light queue
    VESPA_METADATA_SYNC = "vespa_metadata_sync"
    DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
@@ -91,7 +91,6 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
            f"&response_type=code"
            f"&scope=read"
            f"&state={state}"
            f"&prompt=consent"  # prompts user for access; allows choosing workspace
        )

    @classmethod
@@ -50,18 +50,6 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"


class LLMTimeoutError(Exception):
    """
    Exception raised when an LLM call times out.
    """


class LLMRateLimitError(Exception):
    """
    Exception raised when an LLM call is rate limited.
    """


def _base_msg_to_role(msg: BaseMessage) -> str:
    if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
        return "user"
@@ -392,7 +380,6 @@ class DefaultMultiLLM(LLM):
        tool_choice: ToolChoiceOptions | None,
        stream: bool,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
        # litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
        # to a dict representation
@@ -418,7 +405,7 @@ class DefaultMultiLLM(LLM):
            stream=stream,
            # model params
            temperature=0,
            timeout=timeout_override or self._timeout,
            timeout=self._timeout,
            # For now, we don't support parallel tool calls
            # NOTE: we can't pass this in if tools are not specified
            # or else OpenAI throws an error
@@ -437,12 +424,6 @@ class DefaultMultiLLM(LLM):
        except Exception as e:
            self._record_error(processed_prompt, e)
            # for break pointing
            if isinstance(e, litellm.Timeout):
                raise LLMTimeoutError(e)
            elif isinstance(e, litellm.RateLimitError):
                raise LLMRateLimitError(e)
            raise e

    @property
@@ -463,7 +444,6 @@ class DefaultMultiLLM(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()
@@ -471,12 +451,7 @@ class DefaultMultiLLM(LLM):
        response = cast(
            litellm.ModelResponse,
            self._completion(
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=False,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
                prompt, tools, tool_choice, False, structured_response_format
            ),
        )
        choice = response.choices[0]
@@ -494,31 +469,19 @@ class DefaultMultiLLM(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        if LOG_DANSWER_MODEL_INTERACTIONS:
            self.log_model_configs()

        if DISABLE_LITELLM_STREAMING:
            yield self.invoke(
                prompt,
                tools,
                tool_choice,
                structured_response_format,
                timeout_override,
            )
            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
            return

        output = None
        response = cast(
            litellm.CustomStreamWrapper,
            self._completion(
                prompt=prompt,
                tools=tools,
                tool_choice=tool_choice,
                stream=True,
                structured_response_format=structured_response_format,
                timeout_override=timeout_override,
                prompt, tools, tool_choice, True, structured_response_format
            ),
        )
        try:
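With timeout_override threaded through _completion, a caller can tighten the deadline for a single call while leaving the instance-wide default untouched. A hedged usage sketch, where llm stands for an already-configured DefaultMultiLLM:

try:
    response = llm.invoke(prompt, timeout_override=3)  # 3s budget for this call only
except LLMTimeoutError:
    response = None  # the caller decides how to degrade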
@@ -81,7 +81,6 @@ class CustomModelServer(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        return self._execute(prompt)

@@ -91,6 +90,5 @@ class CustomModelServer(LLM):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._execute(prompt)
@@ -90,13 +90,12 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        return self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format, timeout_override
            prompt, tools, tool_choice, structured_response_format
        )

    @abc.abstractmethod
@@ -106,7 +105,6 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> BaseMessage:
        raise NotImplementedError

@@ -116,13 +114,12 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        self._precall(prompt)
        # TODO add a postcall to log model outputs independent of concrete class
        # implementation
        messages = self._stream_implementation(
            prompt, tools, tool_choice, structured_response_format, timeout_override
            prompt, tools, tool_choice, structured_response_format
        )

        tokens = []
@@ -141,6 +138,5 @@ class LLM(abc.ABC):
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
        timeout_override: int | None = None,
    ) -> Iterator[BaseMessage]:
        raise NotImplementedError
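Since the public invoke/stream methods wrap the abstract _invoke_implementation/_stream_implementation hooks, a new backend only fills in those two. A minimal sketch under the assumption that no other abstract members need overriding (the real interface has more):

from collections.abc import Iterator

from langchain_core.messages import AIMessage, BaseMessage


class EchoLLM(LLM):
    # Toy backend that echoes the prompt; illustrates the two hooks only.
    def _invoke_implementation(self, prompt, tools=None, tool_choice=None,
                               structured_response_format=None,
                               timeout_override=None) -> BaseMessage:
        return AIMessage(content=str(prompt))

    def _stream_implementation(self, prompt, tools=None, tool_choice=None,
                               structured_response_format=None,
                               timeout_override=None) -> Iterator[BaseMessage]:
        yield AIMessage(content=str(prompt))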
@@ -5,6 +5,8 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
YES = "yes"
NO = "no"


# Framing/Support/Template Prompts
HISTORY_FRAMING_PROMPT = f"""
For more context, here is the history of the conversation so far that preceded this question:
@@ -22,8 +22,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import (
    get_connector_credential_pair_from_id_for_user,
@@ -230,13 +228,6 @@ def update_cc_pair_status(

    db_session.commit()

    # this speeds up the start of indexing by firing the check immediately
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        kwargs=dict(tenant_id=tenant_id),
        priority=OnyxCeleryPriority.HIGH,
    )

    return JSONResponse(
        status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
    )
@@ -549,14 +540,7 @@ def associate_credential_to_connector(
    metadata: ConnectorCredentialPairMetadata,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
    """NOTE(rkuo): internally discussed and the consensus is this endpoint
    and create_connector_with_mock_credential should be combined.

    The intent of this endpoint is to handle connectors that actually need credentials.
    """

    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
    )(
@@ -579,18 +563,6 @@ def associate_credential_to_connector(
        groups=metadata.groups,
    )

    # trigger indexing immediately
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=OnyxCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},
    )

    logger.info(
        f"associate_credential_to_connector - running check_for_indexing: "
        f"cc_pair={response.data}"
    )

        return response
    except IntegrityError as e:
        logger.error(f"IntegrityError: {e}")
@@ -804,14 +804,6 @@ def create_connector_with_mock_credential(
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
    """NOTE(rkuo): internally discussed and the consensus is this endpoint
    and associate_credential_to_connector should be combined.

    The intent of this endpoint is to handle connectors that don't need credentials,
    AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
    serve this purpose.
    """

    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
    )(
@@ -849,18 +841,6 @@ def create_connector_with_mock_credential(
        groups=connector_data.groups,
    )

    # trigger indexing immediately
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=OnyxCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},
    )

    logger.info(
        f"create_connector_with_mock_credential - running check_for_indexing: "
        f"cc_pair={response.data}"
    )

    create_milestone_and_report(
        user=user,
        distinct_id=user.email if user else tenant_id or "N/A",
@@ -1025,8 +1005,6 @@ def connector_run_once(
        kwargs={"tenant_id": tenant_id},
    )

    logger.info("connector_run_once - running check_for_indexing")

    msg = f"Marked {num_triggers} index attempts with indexing triggers."
    return StatusResponse(
        success=True,
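The recurring primary_app.send_task block in these hunks is the immediate-kickoff pattern: instead of waiting for the next periodic beat tick, the API handler enqueues the check right away at high priority. For reference, its shape (names exactly as in the diff):

primary_app.send_task(
    OnyxCeleryTask.CHECK_FOR_INDEXING,
    kwargs={"tenant_id": tenant_id},
    priority=OnyxCeleryPriority.HIGH,  # jump ahead of routine queue work
)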
@@ -179,10 +179,12 @@ def oauth_callback(
        db_session=db_session,
    )

    # TODO: use a library for url handling
    sep = "&" if "?" in desired_return_url else "?"
    return CallbackResponse(
        redirect_url=f"{desired_return_url}{sep}credentialId={credential.id}"
        redirect_url=(
            f"{desired_return_url}?credentialId={credential.id}"
            if "?" not in desired_return_url
            else f"{desired_return_url}&credentialId={credential.id}"
        )
    )
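The TODO points at the sturdier route; a sketch with the standard library (a suggestion, not what this change ships):

from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse


def add_query_param(url: str, key: str, value: str) -> str:
    # Append a query parameter while preserving any existing query string.
    parts = urlparse(url)
    query = parse_qsl(parts.query)
    query.append((key, value))
    return urlunparse(parts._replace(query=urlencode(query)))


# add_query_param("https://app.example.com/cb?tab=1", "credentialId", "42")
# -> "https://app.example.com/cb?tab=1&credentialId=42"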
@@ -6,15 +6,11 @@ from sqlalchemy.orm import Session

from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import insert_document_set
from onyx.db.document_set import mark_document_set_as_to_be_deleted
from onyx.db.document_set import update_document_set
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.features.document_set.models import CheckDocSetPublicRequest
@@ -33,7 +29,6 @@ def create_document_set(
    document_set_creation_request: DocumentSetCreationRequest,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> int:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
@@ -51,13 +46,6 @@ def create_document_set(
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )

    return document_set_db_model.id


@@ -66,7 +54,6 @@ def patch_document_set(
    document_set_update_request: DocumentSetUpdateRequest,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
@@ -85,19 +72,12 @@ def patch_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )


@router.delete("/admin/document-set/{document_set_id}")
def delete_document_set(
    document_set_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    try:
        mark_document_set_as_to_be_deleted(
@@ -108,12 +88,6 @@ def delete_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )


"""Endpoints for non-admins"""
@@ -197,11 +197,6 @@ def create_deletion_attempt_for_connector_id(
        kwargs={"tenant_id": tenant_id},
    )

    logger.info(
        f"create_deletion_attempt_for_connector_id - running check_for_connector_deletion: "
        f"cc_pair={cc_pair.id}"
    )

    if cc_pair.connector.source == DocumentSource.FILE:
        connector = cc_pair.connector
        file_store = get_default_file_store(db_session)
@@ -34,7 +34,6 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -287,7 +286,7 @@ def bulk_invite_users(
                detail=f"Invalid email address: {email} - {str(e)}",
            )

    if MULTI_TENANT and not DEV_MODE:
    if MULTI_TENANT:
        try:
            fetch_ee_implementation_or_noop(
                "onyx.server.tenants.provisioning", "add_users_to_tenant", None
@@ -70,8 +70,8 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./onyx /app/onyx
COPY ./shared_configs /app/shared_configs
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf
@@ -24,6 +24,35 @@ def generate_auth_token() -> str:


class TenantManager:
    @staticmethod
    def create(
        tenant_id: str | None = None,
        initial_admin_email: str | None = None,
        referral_source: str | None = None,
    ) -> dict[str, str]:
        body = {
            "tenant_id": tenant_id,
            "initial_admin_email": initial_admin_email,
            "referral_source": referral_source,
        }

        token = generate_auth_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "X-API-KEY": "",
            "Content-Type": "application/json",
        }

        response = requests.post(
            url=f"{API_SERVER_URL}/tenants/create",
            json=body,
            headers=headers,
        )

        response.raise_for_status()

        return response.json()

    @staticmethod
    def get_all_users(
        user_performing_action: DATestUser | None = None,
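In the multi-tenant tests this helper is the first call of a scenario; a usage sketch whose values mirror the tests further below:

TenantManager.create(
    tenant_id="tenant_dev1",
    initial_admin_email="test1@test.com",
    referral_source="Data Plane Registration",
)
# The first user who then registers with that email becomes the tenant admin.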
@@ -92,7 +92,6 @@ class UserManager:

        # Set cookies in the headers
        test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
        test_user.cookies = {"fastapiusersauth": session_cookie}
        return test_user

    @staticmethod
@@ -103,7 +102,6 @@ class UserManager:
        response = requests.get(
            url=f"{API_SERVER_URL}/me",
            headers=user_to_verify.headers,
            cookies=user_to_verify.cookies,
        )

        if user_to_verify.is_active is False:
@@ -242,18 +242,6 @@ def reset_postgres_multitenant() -> None:
        schema_name = schema[0]
        cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')

    # Drop tables in the public schema
    cur.execute(
        """
        SELECT tablename FROM pg_tables
        WHERE schemaname = 'public'
        """
    )
    public_tables = cur.fetchall()
    for table in public_tables:
        table_name = table[0]
        cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')

    cur.close()
    conn.close()
@@ -44,7 +44,6 @@ class DATestUser(BaseModel):
    headers: dict
    role: UserRole
    is_active: bool
    cookies: dict = {}


class DATestPersonaLabel(BaseModel):
@@ -4,6 +4,7 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
@@ -12,28 +13,25 @@ from tests.integration.common_utils.test_models import DATestUser


def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    # Create an admin user (the first user created is automatically an admin
    # and also provisions the tenant)
    admin_user1: DATestUser = UserManager.create(
        email="admin@onyx-test.com",
    )

    assert UserManager.is_role(admin_user1, UserRole.ADMIN)
    # Create Tenant 1 and its Admin User
    TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
    test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
    assert UserManager.is_role(test_user1, UserRole.ADMIN)

    # Create Tenant 2 and its Admin User
    admin_user2: DATestUser = UserManager.create(
        email="admin2@onyx-test.com",
    )
    assert UserManager.is_role(admin_user2, UserRole.ADMIN)
    TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
    test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
    assert UserManager.is_role(test_user2, UserRole.ADMIN)

    # Create connectors for Tenant 1
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=admin_user1,
        user_performing_action=test_user1,
    )
    api_key_1: DATestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user1,
        user_performing_action=test_user1,
    )
    api_key_1.headers.update(admin_user1.headers)
    LLMProviderManager.create(user_performing_action=admin_user1)
    api_key_1.headers.update(test_user1.headers)
    LLMProviderManager.create(user_performing_action=test_user1)

    # Seed documents for Tenant 1
    cc_pair_1.documents = []
@@ -51,13 +49,13 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:

    # Create connectors for Tenant 2
    cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=admin_user2,
        user_performing_action=test_user2,
    )
    api_key_2: DATestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user2,
        user_performing_action=test_user2,
    )
    api_key_2.headers.update(admin_user2.headers)
    LLMProviderManager.create(user_performing_action=admin_user2)
    api_key_2.headers.update(test_user2.headers)
    LLMProviderManager.create(user_performing_action=test_user2)

    # Seed documents for Tenant 2
    cc_pair_2.documents = []
@@ -78,17 +76,17 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:

    # Create chat sessions for each user
    chat_session1: DATestChatSession = ChatSessionManager.create(
        user_performing_action=admin_user1
        user_performing_action=test_user1
    )
    chat_session2: DATestChatSession = ChatSessionManager.create(
        user_performing_action=admin_user2
        user_performing_action=test_user2
    )

    # User 1 sends a message and gets a response
    response1 = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=admin_user1,
        user_performing_action=test_user1,
    )
    # Assert that the search tool was used
    assert response1.tool_name == "run_search"
@@ -102,16 +100,14 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    ), "Tenant 2 document IDs should not be in the response"

    # Assert that the contents are correct
    assert any(
        doc["content"] == "Tenant 1 Document Content"
        for doc in response1.tool_result or []
    ), "Tenant 1 Document Content not found in any document"
    for doc in response1.tool_result or []:
        assert doc["content"] == "Tenant 1 Document Content"

    # User 2 sends a message and gets a response
    response2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=admin_user2,
        user_performing_action=test_user2,
    )
    # Assert that the search tool was used
    assert response2.tool_name == "run_search"
@@ -123,18 +119,15 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    assert not response_doc_ids.intersection(
        tenant1_doc_ids
    ), "Tenant 1 document IDs should not be in the response"

    # Assert that the contents are correct
    assert any(
        doc["content"] == "Tenant 2 Document Content"
        for doc in response2.tool_result or []
    ), "Tenant 2 Document Content not found in any document"
    for doc in response2.tool_result or []:
        assert doc["content"] == "Tenant 2 Document Content"

    # User 1 tries to access Tenant 2's documents
    response_cross = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=admin_user1,
        user_performing_action=test_user1,
    )
    # Assert that the search tool was used
    assert response_cross.tool_name == "run_search"
@@ -147,7 +140,7 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    response_cross2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=admin_user2,
        user_performing_action=test_user2,
    )
    # Assert that the search tool was used
    assert response_cross2.tool_name == "run_search"
@@ -4,12 +4,14 @@ from onyx.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


# Test the flow from tenant creation to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
    TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
    test_user: DATestUser = UserManager.create(name="test", email="test@test.com")

    assert UserManager.is_role(test_user, UserRole.ADMIN)
@@ -1,23 +1,23 @@
import time
from datetime import datetime

from onyx.db.models import IndexingStatus
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser


def _verify_index_attempt_pagination(
    cc_pair_id: int,
    index_attempt_ids: list[int],
    index_attempts: list[DATestIndexAttempt],
    page_size: int = 5,
    user_performing_action: DATestUser | None = None,
) -> None:
    retrieved_attempts: list[int] = []
    last_time_started = None  # Track the last time_started seen

    for i in range(0, len(index_attempt_ids), page_size):
    for i in range(0, len(index_attempts), page_size):
        paginated_result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair_id,
            page=(i // page_size),
@@ -26,9 +26,9 @@ def _verify_index_attempt_pagination(
        )

        # Verify that the total items is equal to the length of the index attempts list
        assert paginated_result.total_items == len(index_attempt_ids)
        assert paginated_result.total_items == len(index_attempts)
        # Verify that the number of items in the page is equal to the page size
        assert len(paginated_result.items) == min(page_size, len(index_attempt_ids) - i)
        assert len(paginated_result.items) == min(page_size, len(index_attempts) - i)

        # Verify time ordering within the page (descending order)
        for attempt in paginated_result.items:
@@ -42,7 +42,7 @@ def _verify_index_attempt_pagination(
        retrieved_attempts.extend([attempt.id for attempt in paginated_result.items])

    # Create a set of all the expected index attempt IDs
    all_expected_attempts = set(index_attempt_ids)
    all_expected_attempts = set(attempt.id for attempt in index_attempts)
    # Create a set of all the retrieved index attempt IDs
    all_retrieved_attempts = set(retrieved_attempts)

@@ -51,9 +51,6 @@ def _verify_index_attempt_pagination(


def test_index_attempt_pagination(reset: None) -> None:
    MAX_WAIT = 60
    all_attempt_ids: list[int] = []

    # Create an admin user to perform actions
    user_performing_action: DATestUser = UserManager.create(
        name="admin_performing_action",
@@ -65,49 +62,20 @@ def test_index_attempt_pagination(reset: None) -> None:
        user_performing_action=user_performing_action,
    )

    # Creating a CC pair will create an index attempt as well. Wait for it.
    start = time.monotonic()
    while True:
        paginated_result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair.id,
            page=0,
            page_size=5,
            user_performing_action=user_performing_action,
        )

        if paginated_result.total_items == 1:
            all_attempt_ids.append(paginated_result.items[0].id)
            print("Initial index attempt from cc_pair creation detected. Continuing...")
            break

        elapsed = time.monotonic() - start
        if elapsed > MAX_WAIT:
            raise TimeoutError(
                f"Initial index attempt: Not detected within {MAX_WAIT} seconds."
            )

        print(
            f"Waiting for initial index attempt: elapsed={elapsed:.2f} timeout={MAX_WAIT}"
        )
        time.sleep(1)

    # Create 299 successful index attempts (for 300 total)
    # Create 300 successful index attempts
    base_time = datetime.now()
    generated_attempts = IndexAttemptManager.create_test_index_attempts(
        num_attempts=299,
    all_attempts = IndexAttemptManager.create_test_index_attempts(
        num_attempts=300,
        cc_pair_id=cc_pair.id,
        status=IndexingStatus.SUCCESS,
        base_time=base_time,
    )

    for attempt in generated_attempts:
        all_attempt_ids.append(attempt.id)

    # Verify basic pagination with different page sizes
    print("Verifying basic pagination with page size 5")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempt_ids=all_attempt_ids,
        index_attempts=all_attempts,
        page_size=5,
        user_performing_action=user_performing_action,
    )
@@ -116,7 +84,7 @@ def test_index_attempt_pagination(reset: None) -> None:
    print("Verifying pagination with page size 100")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempt_ids=all_attempt_ids,
        index_attempts=all_attempts,
        page_size=100,
        user_performing_action=user_performing_action,
    )
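The wait loop in the test above is a standard poll-until-ready guard around the index attempt created as a side effect of the CC pair; the same logic, condensed into a reusable sketch:

import time


def wait_for(condition, timeout: float = 60.0, poll: float = 1.0):
    # Poll `condition` until it returns a truthy value or `timeout` elapses.
    start = time.monotonic()
    while True:
        result = condition()
        if result:
            return result
        if time.monotonic() - start > timeout:
            raise TimeoutError(f"condition not met within {timeout} seconds")
        time.sleep(poll)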
@@ -1,423 +0,0 @@
|
||||
services:
|
||||
api_server:
|
||||
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile
|
||||
command: >
|
||||
/bin/sh -c "
|
||||
alembic -n schema_private upgrade head &&
|
||||
echo \"Starting Onyx Api Server\" &&
|
||||
uvicorn onyx.main:app --host 0.0.0.0 --port 8080"
|
||||
depends_on:
|
||||
- relational_db
|
||||
- index
|
||||
- cache
|
||||
- inference_model_server
|
||||
restart: always
|
||||
ports:
|
||||
- "8080:8080"
|
||||
environment:
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
- MULTI_TENANT=true
|
||||
- LOG_LEVEL=DEBUG
|
||||
- AUTH_TYPE=cloud
|
||||
- REQUIRE_EMAIL_VERIFICATION=false
|
||||
- DISABLE_TELEMETRY=true
|
||||
- IMAGE_TAG=test
|
||||
- DEV_MODE=true
|
||||
# Auth Settings
|
||||
- SESSION_EXPIRE_TIME_SECONDS=${SESSION_EXPIRE_TIME_SECONDS:-}
|
||||
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
|
||||
- VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
|
||||
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
|
||||
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
|
||||
- SMTP_SERVER=${SMTP_SERVER:-}
|
||||
- SMTP_PORT=${SMTP_PORT:-587}
|
||||
- SMTP_USER=${SMTP_USER:-}
|
||||
- SMTP_PASS=${SMTP_PASS:-}
|
||||
- ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}
|
||||
- EMAIL_FROM=${EMAIL_FROM:-}
|
||||
- OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
|
||||
- OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
|
||||
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
|
||||
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
|
||||
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
|
||||
# Gen AI Settings
|
||||
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
|
||||
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
|
||||
- DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
|
||||
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
|
||||
- BING_API_KEY=${BING_API_KEY:-}
|
||||
- DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
|
||||
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
|
||||
- TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
|
||||
# Query Options
|
||||
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
|
||||
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
|
||||
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
|
||||
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
|
||||
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
|
||||
- LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
|
||||
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
|
||||
# Other services
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-}
|
||||
# Don't change the NLP model configs unless you know what you're doing
|
||||
- EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-}
|
||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
||||
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
|
||||
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-}
|
||||
- LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
|
||||
- LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
|
||||
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
|
||||
- LOG_ENDPOINT_LATENCY=${LOG_ENDPOINT_LATENCY:-}
|
||||
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
|
||||
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
|
||||
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
|
||||
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
|
||||
# Egnyte OAuth Configs
|
||||
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
||||
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
||||
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
||||
# Linear OAuth Configs
|
||||
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
|
||||
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
# Chat Configs
|
||||
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
|
||||
# Enables the use of bedrock models or IAM Auth
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
|
||||
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
# Seeding configuration
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
    build:
      context: ../../backend
      dockerfile: Dockerfile
    command: >
      /bin/sh -c "
      if [ -f /etc/ssl/certs/custom-ca.crt ]; then
      update-ca-certificates;
      fi &&
      /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf"
    depends_on:
      - relational_db
      - index
      - cache
      - inference_model_server
      - indexing_model_server
    restart: always
    environment:
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
      - MULTI_TENANT=true
      - LOG_LEVEL=DEBUG
      - AUTH_TYPE=cloud
      - REQUIRE_EMAIL_VERIFICATION=false
      - DISABLE_TELEMETRY=true
      - IMAGE_TAG=test
      - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
      - JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-}
      # Gen AI Settings (Needed by OnyxBot)
      - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
      - QA_TIMEOUT=${QA_TIMEOUT:-}
      - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
      - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
      - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
      - GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-}
      - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
      - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      - BING_API_KEY=${BING_API_KEY:-}
      # Query Options
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}
      - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
      - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
      - LANGUAGE_HINT=${LANGUAGE_HINT:-}
      - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
      - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
      # Other Services
      - POSTGRES_HOST=relational_db
      - POSTGRES_USER=${POSTGRES_USER:-}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
      - POSTGRES_DB=${POSTGRES_DB:-}
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
      - WEB_DOMAIN=${WEB_DOMAIN:-}
      # Don't change the NLP model configs unless you know what you're doing
      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
      - DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
      - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
      # Indexing Configs
      - VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-}
      - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
      - ENABLED_CONNECTOR_TYPES=${ENABLED_CONNECTOR_TYPES:-}
      - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
      - DASK_JOB_CLIENT_ENABLED=${DASK_JOB_CLIENT_ENABLED:-}
      - CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
      - EXPERIMENTAL_CHECKPOINTING_ENABLED=${EXPERIMENTAL_CHECKPOINTING_ENABLED:-}
      - CONFLUENCE_CONNECTOR_LABELS_TO_SKIP=${CONFLUENCE_CONNECTOR_LABELS_TO_SKIP:-}
      - JIRA_CONNECTOR_LABELS_TO_SKIP=${JIRA_CONNECTOR_LABELS_TO_SKIP:-}
      - WEB_CONNECTOR_VALIDATE_URLS=${WEB_CONNECTOR_VALIDATE_URLS:-}
      - JIRA_API_VERSION=${JIRA_API_VERSION:-}
      - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
      - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
      - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
      - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
      # Egnyte OAuth Configs
      - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
      - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
      - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
      # Linear OAuth Configs
      - LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
      - LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
      # Celery Configs (defaults are set in the supervisord.conf file;
      # prefer setting them there to keep one source of defaults)
      - CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
      - CELERY_WORKER_LIGHT_CONCURRENCY=${CELERY_WORKER_LIGHT_CONCURRENCY:-}
      - CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-}

      # Onyx SlackBot Configs
      - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
      - DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
      - DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
      - DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-}
      - DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
      - NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
      - DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
      - DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
      # Logging
      # Leave this on pretty please? Nothing sensitive is collected!
      # https://docs.onyx.app/more/telemetry
      - DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
      - LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs
      - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging
      # Log all of Onyx's prompts and interactions with the LLM
      - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
      - LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
      - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}

      # Analytics Configs
      - SENTRY_DSN=${SENTRY_DSN:-}

      # Enterprise Edition stuff
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
    # Uncomment the volume mount below if USE_IAM_AUTH is true and you are using IAM auth for Postgres
    # volumes:
    #   - ./bundle.pem:/app/bundle.pem:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"
    # Uncomment the following lines if you need to include a custom CA certificate.
    # If present, the custom CA certificate is mounted as a volume; the container
    # checks for its existence and updates the system's CA certificates. This
    # allows for secure communication with services using custom SSL certificates.
    # Optional volume mount for CA certificate
    # volumes:
    #   # Maps to the CA_CERT_PATH environment variable in the Dockerfile
    #   - ${CA_CERT_PATH:-./custom-ca.crt}:/etc/ssl/certs/custom-ca.crt:ro

  web_server:
    image: onyxdotapp/onyx-web-server:${IMAGE_TAG:-latest}
    build:
      context: ../../web
      dockerfile: Dockerfile
      args:
        - NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
        - NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA=${NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA:-false}
        - NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
        - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
        - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
        - NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
        - NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
        # Enterprise Edition only
        - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
        # DO NOT TURN ON unless you have EXPLICIT PERMISSION from Onyx.
        - NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED:-false}
    depends_on:
      - api_server
    restart: always
    environment:
      - INTERNAL_URL=http://api_server:8080
      - WEB_DOMAIN=${WEB_DOMAIN:-}
      - THEME_IS_DARK=${THEME_IS_DARK:-}
      - DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}

      # Enterprise Edition only
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-}

  inference_model_server:
    image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
      echo 'Skipping service...';
      exit 0;
      else
      exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}

      # Analytics Configs
      - SENTRY_DSN=${SENTRY_DSN:-}
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  indexing_model_server:
    image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
      echo 'Skipping service...';
      exit 0;
      else
      exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_ONLY=True
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
      - CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}

      # Analytics Configs
      - SENTRY_DSN=${SENTRY_DSN:-}
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  relational_db:
    image: postgres:15.2-alpine
    command: -c 'max_connections=250'
    restart: always
    environment:
      - POSTGRES_USER=${POSTGRES_USER:-postgres}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
    ports:
      - "5432:5432"
    volumes:
      - db_volume:/var/lib/postgresql/data

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
    restart: always
    ports:
      - "19071:19071"
      - "8081:8081"
    volumes:
      - vespa_volume:/opt/vespa/var
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  nginx:
    image: nginx:1.23.4-alpine
    restart: always
    # nginx will immediately crash with `nginx: [emerg] host not found in upstream`
    # if api_server / web_server are not up
    depends_on:
      - api_server
      - web_server
    environment:
      - DOMAIN=localhost
    ports:
      - "80:80"
      - "3000:80" # allow for localhost:3000 usage, since that is the norm
    volumes:
      - ../data/nginx:/etc/nginx/conf.d
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"
    # The specified script waits for the api_server to start up.
    # Without this we've seen issues where nginx shows no error logs but
    # does not receive any traffic
    # NOTE: we have to use dos2unix to remove Carriage Return chars from the file
    # in order to make this work on both Unix-like systems and Windows
    command: >
      /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
      && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"

  cache:
    image: redis:7.4-alpine
    restart: always
    ports:
      - "6379:6379"
    # Docker silently mounts /data even without an explicit volume mount, which enables
    # persistence. Explicitly setting save and appendonly forces ephemeral behavior.
    command: redis-server --save "" --appendonly no

volumes:
  db_volume:
  vespa_volume: # Created by the container itself

  model_cache_huggingface:
  indexing_huggingface_model_cache:
@@ -18,11 +18,7 @@ import AdvancedFormPage from "./pages/Advanced";
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
import CreateCredential from "@/components/credentials/actions/CreateCredential";
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
import {
  ConfigurableSources,
  oauthSupportedSources,
  ValidSources,
} from "@/lib/types";
import { ConfigurableSources, oauthSupportedSources } from "@/lib/types";
import {
  Credential,
  credentialTemplates,
@@ -448,7 +444,7 @@ export default function AddConnector({
        <CardSection>
          <Title className="mb-2 text-lg">Select a credential</Title>

          {connector == ValidSources.Gmail ? (
          {connector == "gmail" ? (
            <GmailMain />
          ) : (
            <>

@@ -949,15 +949,15 @@ export function ChatPage({
    // Check if all messages are currently rendered
    if (currentVisibleRange.end < messageHistory.length) {
      // Update visible range to include the last messages
      updateCurrentVisibleRange({
        start: Math.max(
          0,
          messageHistory.length -
            (currentVisibleRange.end - currentVisibleRange.start)
        ),
        end: messageHistory.length,
        mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
      });
      // updateCurrentVisibleRange({
      //   start: Math.max(
      //     0,
      //     messageHistory.length -
      //       (currentVisibleRange.end - currentVisibleRange.start)
      //   ),
      //   end: messageHistory.length,
      //   mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
      // });

      // Wait for the state update and re-render before scrolling
      setTimeout(() => {
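
// A condensed sketch of the pattern the hunk above relies on. The helper and
// its parameters are hypothetical stand-ins named after the surrounding diff,
// not the component's actual code: first grow the rendered window to the tail
// of the history, then defer the scroll until after the re-render has
// (typically) flushed, since scrolling before the tail messages mount would
// target a stale scrollHeight.
function scrollToLatest(
  historyLength: number,
  windowSize: number,
  setRange: (r: { start: number; end: number }) => void,
  scrollEl: HTMLElement
): void {
  // Render only the last `windowSize` messages.
  setRange({ start: Math.max(0, historyLength - windowSize), end: historyLength });
  // A zero-delay timeout usually runs after React commits the update, so the
  // newly mounted messages contribute to scrollHeight by the time we scroll.
  setTimeout(() => {
    scrollEl.scrollTop = scrollEl.scrollHeight;
  }, 0);
}
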
@@ -1121,7 +1121,6 @@ export function ChatPage({
        "Continue Generating (pick up exactly where you left off)",
    });
  };
  const [uncaughtError, setUncaughtError] = useState<string | null>(null);

  const onSubmit = async ({
    messageIdToResend,
@@ -1550,23 +1549,8 @@ export function ChatPage({
            }
          );
        } else if (Object.hasOwn(packet, "error")) {
          if (
            sub_questions.length > 0 &&
            sub_questions
              .filter((q) => q.level === 0)
              .every((q) => q.is_stopped === true)
          ) {
            setUncaughtError((packet as StreamingError).error);
            updateChatState("input");
            setAgenticGenerating(false);
            setAlternativeGeneratingAssistant(null);
            setSubmittedMessage("");
            return;
            // throw new Error((packet as StreamingError).error);
          } else {
            error = (packet as StreamingError).error;
            stackTrace = (packet as StreamingError).stack_trace;
          }
          error = (packet as StreamingError).error;
          stackTrace = (packet as StreamingError).stack_trace;
        } else if (Object.hasOwn(packet, "message_id")) {
          finalMessage = packet as BackendMessage;
        } else if (Object.hasOwn(packet, "stop_reason")) {
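
// Sketch of the guard introduced above (the interface is a simplified,
// hypothetical stand-in for the component's `sub_questions` shape): a
// streaming error is surfaced as UI state instead of thrown only when every
// top-level sub-question has already stopped, i.e. the agentic answer is
// effectively finished and throwing would discard it.
interface SubQuestionLike {
  level: number;
  is_stopped: boolean;
}
function shouldSurfaceInsteadOfThrow(subQuestions: SubQuestionLike[]): boolean {
  // Mirrors the diff's `sub_questions.length > 0 && .filter(...).every(...)`.
  return (
    subQuestions.length > 0 &&
    subQuestions.filter((q) => q.level === 0).every((q) => q.is_stopped)
  );
}
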
@@ -1885,6 +1869,7 @@ export function ChatPage({
      newRange: VisibleRange,
      forceUpdate?: boolean
    ) => {
      console.log("updateCurrentVisibleRange", newRange);
      if (
        scrollInitialized.current &&
        visibleRange.get(loadedIdSessionRef.current) == undefined &&
@@ -1923,26 +1908,54 @@ export function ChatPage({
      scrollInitialized.current = true;
    }
  };
  const setVisibleRangeForCurrentSessionId = (newRange: VisibleRange) => {
    console.log("setVisibleRangeForCurrentSessionId", newRange);
    setVisibleRange((prevState) => {
      const newState = new Map(prevState);
      newState.set(currentSessionId(), newRange);
      return newState;
    });
  };

  const updateVisibleRangeBasedOnScroll = () => {
    if (!scrollInitialized.current) return;
  function updateVisibleRangeBasedOnScroll() {
    const scrollableDiv = scrollableDivRef.current;
    if (!scrollableDiv) return;

    const viewportHeight = scrollableDiv.clientHeight;
    let mostVisibleMessageIndex = -1;
    const distanceFromBottom =
      scrollableDiv.scrollHeight -
      scrollableDiv.scrollTop -
      scrollableDiv.clientHeight;

    const isNearBottom = distanceFromBottom < 200;

    // If user is near bottom, we treat the last message as "most visible"
    if (isNearBottom) {
      const startIndex = Math.max(0, messageHistory.length - BUFFER_COUNT);
      const endIndex = messageHistory.length;
      setVisibleRangeForCurrentSessionId({
        start: startIndex,
        end: endIndex,
        mostVisibleMessageId:
          messageHistory.length > 0
            ? messageHistory[messageHistory.length - 1].messageId
            : null,
      });
      return;
    }

    // otherwise do the bounding rect logic:
    let mostVisibleMessageIndex = -1;
    const viewportHeight = scrollableDiv.clientHeight;
    messageHistory.forEach((message, index) => {
      const messageElement = document.getElementById(
        `message-${message.messageId}`
      );
      if (messageElement) {
        const rect = messageElement.getBoundingClientRect();
        const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
      const elem = document.getElementById(`message-${message.messageId}`);
      if (elem) {
        const rect = elem.getBoundingClientRect();
        const isVisible = rect.top < viewportHeight && rect.bottom >= 0;
        if (isVisible && index > mostVisibleMessageIndex) {
          mostVisibleMessageIndex = index;
        }
      }
      // clientScrollToBottom;
    });

    if (mostVisibleMessageIndex !== -1) {
@@ -1951,34 +1964,50 @@ export function ChatPage({
        messageHistory.length,
        mostVisibleMessageIndex + BUFFER_COUNT + 1
      );

      updateCurrentVisibleRange({
      setVisibleRangeForCurrentSessionId({
        start: startIndex,
        end: endIndex,
        mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
      });
    }
  };
  }
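
// Condensed, standalone sketch of the windowing logic above (hypothetical; the
// component reads `messageHistory` and `BUFFER_COUNT` from its own scope, and
// the symmetric `start` padding is assumed since the hunk only shows the `end`
// computation). Note the visibility test: the old `rect.bottom <=
// viewportHeight && rect.bottom > 0` only counted elements whose bottom edge
// sat inside the viewport, while the new `rect.top < viewportHeight &&
// rect.bottom >= 0` counts any element that overlaps the viewport at all.
function computeVisibleWindow(
  messageIds: number[],
  viewportHeight: number,
  bufferCount: number
): { start: number; end: number; mostVisibleMessageId: number } | null {
  let mostVisibleIndex = -1;
  messageIds.forEach((id, index) => {
    const rect = document
      .getElementById(`message-${id}`)
      ?.getBoundingClientRect();
    // Overlap test from the new code; `index > mostVisibleIndex` means the
    // last overlapping message wins, matching the diff.
    if (
      rect &&
      rect.top < viewportHeight &&
      rect.bottom >= 0 &&
      index > mostVisibleIndex
    ) {
      mostVisibleIndex = index;
    }
  });
  if (mostVisibleIndex === -1) return null;
  // Pad the window by `bufferCount` messages on each side so near-viewport
  // messages are already mounted when the user scrolls.
  return {
    start: Math.max(0, mostVisibleIndex - bufferCount),
    end: Math.min(messageIds.length, mostVisibleIndex + bufferCount + 1),
    mostVisibleMessageId: messageIds[mostVisibleIndex],
  };
}
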

  useEffect(() => {
    initializeVisibleRange();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [router, messageHistory]);

  useLayoutEffect(() => {
    const scrollableDiv = scrollableDivRef.current;
  useEffect(() => {
    console.log("useEffect has been called");
    const scrollEl = scrollableDivRef.current;
    if (!scrollEl) {
      console.log("no scrollEl");
      return;
    }

    const handleScroll = () => {
      updateVisibleRangeBasedOnScroll();
      requestAnimationFrame(() => {
        updateVisibleRangeBasedOnScroll();
      });
    };

    scrollableDiv?.addEventListener("scroll", handleScroll);
    const attachScrollListener = () => {
      if (scrollEl) {
        scrollEl.addEventListener("scroll", handleScroll);
      } else {
        console.log("scrollEl not available, retrying in 100ms");
        setTimeout(attachScrollListener, 100);
      }
    };

    attachScrollListener();

    return () => {
      scrollableDiv?.removeEventListener("scroll", handleScroll);
      if (scrollEl) {
        scrollEl.removeEventListener("scroll", handleScroll);
      }
    };
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [messageHistory]);
  }, [scrollableDivRef, messageHistory, currentSessionId()]);
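
// Minimal sketch of the listener wiring used by the new effect (the helper
// name is hypothetical): the handler defers the range recomputation to
// requestAnimationFrame so it runs aligned with the next paint, and cleanup
// removes the same function reference that was added. Like the diff, this
// queues one rAF callback per scroll event rather than coalescing several
// events into a single frame.
function attachThrottledScroll(
  el: HTMLElement,
  recompute: () => void
): () => void {
  const handler = () => {
    requestAnimationFrame(recompute);
  };
  el.addEventListener("scroll", handler);
  // The returned disposer mirrors the effect's cleanup.
  return () => el.removeEventListener("scroll", handler);
}
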

  const imageFileInMessageHistory = useMemo(() => {
    return messageHistory
@@ -2055,7 +2084,6 @@ export function ChatPage({
      }

      const data = await response.json();

      router.push(data.redirect_url);
    } catch (error) {
      console.error("Error seeding chat from Slack:", error);
@@ -2510,7 +2538,7 @@ export function ChatPage({
                ? messageHistory
                : messageHistory.slice(
                    currentVisibleRange.start,
                    currentVisibleRange.end
                    currentVisibleRange.end + 2
                  )
              ).map((message, fauxIndex) => {
                const i =
@@ -2650,7 +2678,6 @@ export function ChatPage({
                  {message.sub_questions &&
                  message.sub_questions.length > 0 ? (
                    <AgenticMessage
                      error={uncaughtError}
                      docSidebarToggled={
                        documentSidebarToggled &&
                        (selectedMessageForDocDisplay ==

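// Aside on the `+ 2` above: two extra messages past the window's end are
// rendered as overscan. A hypothetical standalone equivalent of the slice;
// Array.prototype.slice clamps its end argument to the array length, so the
// overscan is safe even when the window already reaches the tail.
const windowWithOverscan = <T,>(history: T[], start: number, end: number): T[] =>
  history.slice(start, end + 2);
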
@@ -80,7 +80,6 @@ export const AgenticMessage = ({
  agenticDocs,
  secondLevelSubquestions,
  toggleDocDisplay,
  error,
}: {
  docSidebarToggled?: boolean;
  isImprovement?: boolean | null;
@@ -111,7 +110,6 @@ export const AgenticMessage = ({
  regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
  setPresentingDocument?: (document: OnyxDocument) => void;
  toggleDocDisplay?: (agentic: boolean) => void;
  error?: string | null;
}) => {
  const [noShowingMessage, setNoShowingMessage] = useState(isComplete);

@@ -485,28 +483,11 @@ export const AgenticMessage = ({
                ) : (
                  content
                )}
                {error && (
                  <p className="mt-2 text-red-700 text-sm my-auto">
                    {error}
                  </p>
                )}
              </div>
            </div>
          </>
        ) : isComplete ? (
          error && (
            <p className="mt-2 mx-4 text-red-700 text-sm my-auto">
              {error}
            </p>
          )
        ) : (
          <>
            {error && (
              <p className="mt-2 mx-4 text-red-700 text-sm my-auto">
                {error}
              </p>
            )}
          </>
        ) : isComplete ? null : (
          <></>
        )}
        {handleFeedback &&
          (isActive ? (

@@ -185,7 +185,6 @@ export const AIMessage = ({
  setPresentingDocument,
  index,
  toggledDocumentSidebar,
  removePadding,
}: {
  index?: number;
  shared?: boolean;
@@ -214,7 +213,6 @@ export const AIMessage = ({
  overriddenModel?: string;
  regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
  setPresentingDocument?: (document: OnyxDocument) => void;
  removePadding?: boolean;
}) => {
  const toolCallGenerating = toolCall && !toolCall.tool_result;

@@ -400,9 +398,7 @@ export const AIMessage = ({
      <div
        id={isComplete ? "onyx-ai-message" : undefined}
        ref={trackedElementRef}
        className={`py-5 ml-4 lg:px-5 relative flex

        ${removePadding && "!pl-24 -mt-12"}`}
        className={`py-5 ml-4 lg:px-5 relative flex `}
      >
        <div
          className={`mx-auto ${
@@ -411,13 +407,11 @@ export const AIMessage = ({
        >
          <div className={`lg:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
            <div className="flex">
              {!removePadding && (
                <AssistantIcon
                  className="mobile:hidden"
                  size={24}
                  assistant={alternativeAssistant || currentPersona}
                />
              )}
              <AssistantIcon
                className="mobile:hidden"
                size={24}
                assistant={alternativeAssistant || currentPersona}
              />

              <div className="w-full">
                <div className="max-w-message-max break-words">
@@ -594,8 +588,7 @@ export const AIMessage = ({
          )}
        </div>

        {!removePadding &&
          handleFeedback &&
        {handleFeedback &&
          (isActive ? (
            <div
              className={`

@@ -772,7 +772,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
    advanced_values: [],
  },
  linear: {
    description: "Configure Linear connector",
    description: "Configure Dropbox connector",
    values: [],
    advanced_values: [],
  },

@@ -114,7 +114,7 @@ export interface LoopioCredentialJson {
}

export interface LinearCredentialJson {
  linear_access_token: string;
  linear_api_key: string;
}

export interface HubSpotCredentialJson {
@@ -250,7 +250,7 @@ export const credentialTemplates: Record<ValidSources, any> = {
    gong_access_key_secret: "",
  } as GongCredentialJson,
  zulip: { zuliprc_content: "" } as ZulipCredentialJson,
  linear: { linear_access_token: "" } as LinearCredentialJson,
  linear: { linear_api_key: "" } as LinearCredentialJson,
  hubspot: { hubspot_access_token: "" } as HubSpotCredentialJson,
  document360: {
    portal_id: "",
@@ -404,7 +404,7 @@ export const credentialDisplayNames: Record<string, string> = {
  loopio_client_token: "Loopio Client Token",

  // Linear
  linear_access_token: "Linear Access Token",
  linear_api_key: "Linear API Key",

  // HubSpot
  hubspot_access_token: "HubSpot Access Token",