Compare commits

..

35 Commits

Author SHA1 Message Date
pablodanswer
547fefb306 nit 2025-02-13 15:10:29 -08:00
pablodanswer
2c36dd162d update 2025-02-13 14:55:53 -08:00
pablodanswer
e0f1ca974e quick nit 2025-02-13 14:55:53 -08:00
pablodanswer
737c6118a4 reduce errors in workers 2025-02-13 14:55:53 -08:00
Yuhong Sun
c87261cda7 Fix edge case with run functions in parallel 2025-02-12 17:57:39 -08:00
pablonyx
e030b0a6fc Address (#3955) 2025-02-12 13:53:13 -08:00
Yuhong Sun
61136975ad Don't build model server every night (#3973) 2025-02-12 13:08:05 -08:00
Weves
0c74bbf9ed Clean illegal chars in metadata 2025-02-12 11:49:16 -08:00
pablonyx
12b2126e69 Update assistants visibility, minor UX, .. (#3965)
* update assistant logic

* quick nit

* k

* fix "featured" logic

* Small tweaks

* k

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2025-02-12 00:43:20 +00:00
Chris Weaver
037943c6ff Support share/view IDs for Airtable (#3967) 2025-02-11 16:19:38 -08:00
pablonyx
f9485b1325 Ensure sidepanel defaults sidebar off (#3844)
* ensure sidepanel defaults sidepanel off

* address comment

* reformat

* initial visible
2025-02-11 22:22:56 +00:00
rkuo-danswer
552a0630fe Merge pull request #3948 from onyx-dot-app/feature/beat_rtvar
refactoring and update multiplier in real time
2025-02-11 14:05:14 -08:00
Richard Kuo (Danswer)
5bf520d8b8 comments 2025-02-11 14:04:49 -08:00
Weves
7dc5a77946 Improve starter message splitting 2025-02-11 11:10:13 -08:00
rkuo-danswer
03abd4a1bc Merge pull request #3938 from onyx-dot-app/feature/model_server_logs
improve gpu detection functions and logging in model server
2025-02-11 09:43:25 -08:00
Richard Kuo (Danswer)
16d6d708f6 update logging 2025-02-11 09:15:39 -08:00
Richard Kuo
9740ed32b5 fix reading redis values as floats 2025-02-10 20:48:55 -08:00
rkuo-danswer
b56877cc2e Bugfix/dedupe ids (#3952)
* dedupe make_private_persona and update test

* add comment

* comments, and just have duplicate user id's for the test instead of modifying edit

* found the magic word

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-11 02:27:55 +00:00
pablodanswer
da5c83a96d k 2025-02-10 17:45:00 -08:00
Weves
818225c60e Fix starter message overflow 2025-02-10 17:17:31 -08:00
Weves
d78a1fe9c6 Fix for red background 2025-02-10 16:36:26 -08:00
Weves
05b3e594b5 Increase timeout for reasoning models + make o1 available by default 2025-02-10 16:11:01 -08:00
Richard Kuo (Danswer)
5a4d007cf9 comments 2025-02-10 15:03:59 -08:00
pablonyx
3b25a2dd84 Ux improvements (#3947)
* black history sidebar

* misc improvements

* minor misc ux improvemnts

* quick nit

* add nits

* quick nit
2025-02-10 12:18:41 -08:00
pablonyx
baee4c5f22 Multi tenant specific error page (#3928)
Multi tenant specific error page
2025-02-10 11:51:29 -08:00
Richard Kuo (Danswer)
5e32f9d922 refactoring and update multiplier in real time 2025-02-10 11:20:38 -08:00
pablonyx
1454e7e07d New ux dark (#3944) 2025-02-09 21:14:32 -08:00
rkuo-danswer
6848337445 add validation for pruning/group sync etc (#3882)
* add validation for pruning

* fix missing class

* get external group sync validation working

* backport fix for pruning check

* fix pruning

* log the payload id

* remove scan_iter from pruning

* missed removed scan_iter, also remove other scan_iters and replace with sscan_iter of the lookup table

* external group sync needs active signal. h

* log the payload id when the task starts

* log the payload id in more places

* use the replica

* increase primary pool and slow down beat

* scale sql pool based on concurrency

* fix concurrency

* add debugging for external group sync and tenant

* remove debugging and fix payload id

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-10 03:12:21 +00:00
pablonyx
519fbd897e Add Dark Mode (#3936)
* k

* intermediate unification

* many changes

* update dark mode configs

* updates

* decent state

* functional

* mostly clean

* updaet model selector

* finalize

* calendar update

* additional styling

* nit

* k

* update colors

* push change

* k

* update

* k

* update

* address additions

* quick nit
2025-02-09 23:09:40 +00:00
evan-danswer
217569104b added context type for when internet search tool is used (#3930) 2025-02-08 20:44:38 -08:00
rkuo-danswer
4c184bb7f0 Bugfix/slack stop 2 (#3916)
* use callback in slim doc functions

* more callbacks

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-08 23:45:41 +00:00
rkuo-danswer
a222fae7c8 Bugfix/beat templates (#3754)
* WIP

* migrate most beat tasks to fan out strategy

* fix kwargs

* migrate EE tasks

* lock on the task_name level

* typo fix

* transform beat tasks for cloud

* cloud multiplier is only for cloud tasks

* bumpity

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-08 06:57:57 +00:00
pablonyx
94788cda53 Update display (#3934)
* update display

* quick nit
2025-02-08 02:07:47 +00:00
Richard Kuo (Danswer)
fb931ee4de fixes 2025-02-07 17:28:17 -08:00
Richard Kuo (Danswer)
bc2c56dfb6 improve gpu detection functions and logging in model server 2025-02-07 16:59:02 -08:00
330 changed files with 5297 additions and 3938 deletions

View File

@@ -4,6 +4,9 @@ on:
push:
tags:
- "*"
paths:
- 'backend/model_server/**'
- 'backend/Dockerfile.model_server'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}

View File

@@ -133,3 +133,4 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -0,0 +1,32 @@
"""set built in to default
Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Prior to this migration / point in the codebase history,
# built in personas were implicitly treated as default personas (with no option to change this)
# This migration makes that explicit
op.execute(
"""
UPDATE persona
SET is_default_persona = TRUE
WHERE builtin_persona = TRUE
"""
)
def downgrade() -> None:
pass

View File

@@ -3,42 +3,44 @@ from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
beat_system_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
ee_beat_system_tasks: list[dict] = []
ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
]
)
ee_tasks_to_schedule: list[dict] = []
@@ -65,9 +67,14 @@ if not MULTI_TENANT:
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule
return ee_tasks_to_schedule + base_get_tasks_to_schedule()

View File

@@ -15,6 +15,9 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -23,19 +26,22 @@ def make_persona_private(
).delete(synchronize_session="fetch")
if user_ids:
for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
create_notification(
user_id=user_uuid,
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
if group_ids:
for group_id in group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)

View File

@@ -365,7 +365,9 @@ def confluence_doc_sync(
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():

View File

@@ -15,6 +15,7 @@ logger = setup_logger()
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -24,7 +25,9 @@ def _get_slim_doc_generator(
)
return gmail_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp()
start=start_time,
end=current_time.timestamp(),
callback=callback,
)
@@ -40,7 +43,9 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:

View File

@@ -21,6 +21,7 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -30,7 +31,9 @@ def _get_slim_doc_generator(
)
return google_drive_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp()
start=start_time,
end=current_time.timestamp(),
callback=callback,
)

View File

@@ -20,19 +20,11 @@ def _get_slack_document_ids_and_channels(
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -40,6 +32,14 @@ def _get_slack_document_ids_and_channels(
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map

View File

@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
class GPUStatus:
CUDA = "cuda"
MAC_MPS = "mps"
NONE = "none"

View File

@@ -12,6 +12,7 @@ import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}"
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
logger.info(
@@ -403,6 +408,14 @@ async def embed_text(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
logger.info(
f"event=embedding_model "
f"texts={len(texts)} "
f"chars={total_chars} "
f"model={model_name} "
f"gpu={gpu_type} "
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
@@ -455,8 +468,15 @@ async def litellm_rerank(
@router.post("/bi-encoder-embed")
async def process_embed_request(
async def route_bi_encoder_embed(
request: Request,
embed_request: EmbedRequest,
) -> EmbedResponse:
return await process_embed_request(embed_request, request.app.state.gpu_type)
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -484,6 +504,7 @@ async def process_embed_request(
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:

View File

@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
if torch.cuda.is_available():
logger.notice("CUDA GPU is available")
elif torch.backends.mps.is_available():
logger.notice("Mac MPS is available")
else:
logger.notice("GPU is not available, using CPU")
gpu_type = get_gpu_type()
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
app.state.gpu_type = gpu_type
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")

View File

@@ -1,7 +1,9 @@
import torch
from fastapi import APIRouter
from fastapi import Response
from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type
router = APIRouter(prefix="/api")
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
@router.get("/gpu-status")
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():
return {"gpu_available": True, "type": "mps"}
else:
return {"gpu_available": False, "type": "none"}
async def route_gpu_status() -> dict[str, bool | str]:
gpu_type = get_gpu_type()
gpu_available = gpu_type != GPUStatus.NONE
return {"gpu_available": gpu_available, "type": gpu_type}

View File

@@ -8,6 +8,9 @@ from typing import Any
from typing import cast
from typing import TypeVar
import torch
from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
return cast(F, wrapped_sync_func)
return decorator
def get_gpu_type() -> str:
if torch.cuda.is_available():
return GPUStatus.CUDA
if torch.backends.mps.is_available():
return GPUStatus.MAC_MPS
return GPUStatus.NONE

View File

@@ -9,6 +9,7 @@ class CoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
@@ -17,4 +18,4 @@ class SubgraphCoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add] = []
log_messages: Annotated[list[str], add]

View File

@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
@@ -12,39 +12,12 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
def check_sub_answer(
@@ -80,46 +53,14 @@ def check_sub_answer(
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
response = list(
fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - check sub answer")
if agent_error:
answer_quality = True
log_result = agent_error.error_result
else:
if response:
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
else:
answer_quality = True
quality_str = "yes - because LLM error"
log_result = f"Answer quality: {quality_str}"
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
@@ -128,7 +69,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=log_result,
result=f"Answer quality: {quality_str}",
)
],
)

View File

@@ -16,20 +16,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -44,20 +30,11 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
def generate_sub_answer(
state: AnswerQuestionState,
@@ -80,8 +57,6 @@ def generate_sub_answer(
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
@@ -104,67 +79,41 @@ def generate_sub_answer(
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
logger.error("LLM Rate Limit Error - generate sub answer")
response.append(content)
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -182,7 +131,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result=log_results or "",
result="",
)
],
)

View File

@@ -42,8 +42,10 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str
question_id: str
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.

View File

@@ -26,18 +26,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
@@ -53,16 +42,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.context.search.models import InferenceSection
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
@@ -72,12 +57,6 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
def generate_initial_answer(
state: SubQuestionRetrievalState,
@@ -245,82 +224,30 @@ def generate_initial_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
start_stream_token = datetime.now()
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return InitialAnswerUpdate(
initial_answer=None,
error=AgentErrorLoggingFormat(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"

View File

@@ -25,7 +25,7 @@ def validate_initial_answer(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True # not actually required as already streamed out. Refinement will do similar
verdict = True
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,

View File

@@ -23,18 +23,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -45,11 +33,6 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
@@ -60,12 +43,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
def decompose_orig_question(
state: SubQuestionRetrievalState,
@@ -135,35 +112,11 @@ def decompose_orig_question(
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
except LLMTimeoutError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -172,19 +125,19 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
if agent_error:
initial_sub_questions: list[str] = []
log_result = agent_error.error_result
else:
deomposition_response = merge_content(*streamed_tokens)
deomposition_response = merge_content(*streamed_tokens)
list_of_subqs = cast(str, deomposition_response).split("\n")
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
return InitialQuestionDecompositionUpdate(
initial_sub_questions=initial_sub_questions,
initial_sub_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
@@ -198,7 +151,7 @@ def decompose_orig_question(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=log_result,
result=f"decomposed original question into {len(decomp_list)} subquestions",
)
],
)

View File

@@ -252,7 +252,9 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(log_messages=[])
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,

View File

@@ -1,7 +1,6 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -11,37 +10,14 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
def compare_answers(
@@ -64,46 +40,15 @@ def compare_answers(
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLoggingFormat | None = None
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
resp = model.invoke(msg)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
log_result = f"Answer comparison: {refined_answer_improvement}"
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
write_custom_event(
"refined_answer_improvement",
@@ -120,7 +65,7 @@ def compare_answers(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=log_result,
result=f"Answer comparison: {refined_answer_improvement}",
)
],
)

View File

@@ -21,18 +21,6 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
@@ -42,25 +30,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
def create_refined_sub_questions(
@@ -123,65 +96,29 @@ def create_refined_sub_questions(
# Grader
model = graph_config.tooling.fast_llm
agent_error: AgentErrorLoggingFormat | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
response = merge_content(*streamed_tokens)
raise ValueError("LLM response is not a string")
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
@@ -191,7 +128,7 @@ def create_refined_sub_questions(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=log_result,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
)
],
)

View File

@@ -26,19 +26,6 @@ def decide_refinement_need(
decision = True # TODO: just for current testing purposes
if state.error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -21,9 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
@@ -84,7 +81,6 @@ def extract_entities_terms(
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -11,6 +11,7 @@ from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
@@ -22,18 +23,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
@@ -53,14 +43,8 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
@@ -72,15 +56,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
def generate_refined_answer(
@@ -256,80 +231,28 @@ def generate_refined_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLoggingFormat | None = None
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -343,6 +266,49 @@ def generate_refined_answer(
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (

View File

@@ -17,7 +17,6 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -77,7 +76,6 @@ class InitialAnswerUpdate(LoggerUpdate):
"""
initial_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
@@ -90,7 +88,6 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
refined_answer: str | None = None
error: AgentErrorLoggingFormat | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False

View File

@@ -16,40 +16,14 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)
def expand_queries(
@@ -80,43 +54,13 @@ def expand_queries(
)
]
agent_error: AgentErrorLoggingFormat | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
try:
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
except LLMRateLimitError:
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
# use subquestion as query if query generation fails
if agent_error:
llm_response = ""
rewritten_queries = [question]
log_result = agent_error.error_result
else:
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
rewritten_queries = llm_response.split("\n")
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
@@ -125,7 +69,7 @@ def expand_queries(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=log_result,
result=f"Number of expanded queries: {len(rewritten_queries)}",
)
],
)

View File

@@ -1,6 +1,5 @@
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
@@ -11,41 +10,12 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLoggingFormat
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)
def verify_documents(
@@ -56,7 +26,7 @@ def verify_documents(
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing AgentSearchConfig
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
verified_documents: list[InferenceSection]
@@ -81,42 +51,11 @@ def verify_documents(
)
]
agent_error: AgentErrorLoggingFormat | None = None
response: BaseMessage | None = None
response = fast_llm.invoke(msg)
try:
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
agent_error = AgentErrorLoggingFormat(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - verify documents")
if agent_error or response is None:
verified_documents = [retrieved_document_to_verify]
else:
verified_documents = []
if isinstance(response.content, str) and binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents.append(retrieved_document_to_verify)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,

View File

@@ -21,13 +21,9 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalInput(SubgraphCoreState):
# exception from 'no default value'for LangGraph input states
# Here, sub_question_id default Nonoe implies usage for the
# original question. This is sometimes needed for nested sub-graphs
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
question: str
base_search: bool
## Update/Return States
@@ -92,4 +88,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str
query_to_retrieve: str = ""

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,7 +21,6 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -34,7 +33,6 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -74,15 +72,13 @@ def _parse_agent_event(
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput,
graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -96,7 +92,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput,
input: BasicInput | MainInput_a,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -127,7 +123,9 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput(log_messages=[])
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@@ -174,7 +172,9 @@ if __name__ == "__main__":
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput(log_messages=[])
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

View File

@@ -150,17 +150,3 @@ def get_prompt_enrichment_components(
history=history,
date_str=date_str,
)
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()

View File

@@ -1,17 +0,0 @@
from enum import Enum
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"

View File

@@ -1,5 +1,3 @@
from typing import Any
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
@@ -58,12 +56,6 @@ class InitialAgentResultStats(BaseModel):
agent_effectiveness: dict[str, float | int | None]
class AgentErrorLoggingFormat(BaseModel):
error_message: str
error_type: str
error_result: str | None = None
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
@@ -134,12 +126,3 @@ class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str
class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"
BaseMessage_Content = str | list[str | dict[str, Any]]

View File

@@ -20,7 +20,6 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -35,9 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -50,8 +46,6 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -71,9 +65,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
BaseMessage_Content = str | list[str | dict[str, Any]]
# Post-processing
@@ -379,24 +372,8 @@ def summarize_history(
)
)
try:
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
history_response = llm.invoke(history_context_prompt)
assert isinstance(history_response.content, str)
return history_response.content

View File

@@ -1,41 +1,56 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
task_logger = get_task_logger(__name__)
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler):
"""This scheduler is useful because we can dynamically adjust task generation rates
through it."""
RELOAD_INTERVAL = 60
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
self._reload_interval = timedelta(
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
)
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._try_updating_schedule()
logger.info(f"Set reload interval to {self._reload_interval}")
task_logger.info(
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
)
def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float:
retval = super().tick()
@@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
task_logger.debug("Reload interval reached, initiating task update")
try:
self._try_updating_schedule()
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
except (AttributeError, KeyError):
task_logger.exception("Failed to process task configuration")
except Exception:
task_logger.exception("Unexpected error updating tasks")
self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval
def _generate_schedule(
self, tenant_ids: list[str] | list[None]
self, tenant_ids: list[str] | list[None], beat_multiplier: float
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants
# cloud tasks are system wide and thus only need to be on the beat schedule
# once for all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
)
cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
beat_multiplier
)
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
@@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
"kwargs": task.get("kwargs", {}),
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
task_logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task
# regular task beats are multiplied across all tenants
# note that currently this just schedules for a single tenant in self hosted
# and doesn't do anything in the cloud because it's much more scalable
# to schedule a single cloud beat task to dispatch per tenant tasks.
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
@@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):
for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info(
task_logger.debug(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
@@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {tenant_task_name}")
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(
task_logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
@@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):
def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
logger.info("_try_updating_schedule starting")
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
task_logger.debug(f"Found {len(tenant_ids)} IDs")
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# if the schedule or beat multiplier has changed, update
while True:
if beat_multiplier != self.last_beat_multiplier:
do_update = True
break
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
if not DynamicTenantScheduler._compare_schedules(
current_schedule, new_schedule
):
do_update = True
break
new_schedule = self._generate_schedule(tenant_ids)
break
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
if not do_update:
# exit early if nothing changed
task_logger.info(
f"_try_updating_schedule - Schedule unchanged: "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)
return
logger.info(
# schedule needs updating
task_logger.debug(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
@@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
# Ensure changes are persisted
self.sync()
logger.info("_try_updating_schedule: Schedule updated successfully")
task_logger.info(
f"_try_updating_schedule - Schedule updated: "
f"prev_num_tasks={len(current_schedule)} "
f"prev_beat_multiplier={self.last_beat_multiplier} "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)
self.last_beat_multiplier = beat_multiplier
@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed.
"""Compare schedules by task name only to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
@@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
task_logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)

View File

@@ -84,8 +84,10 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
EXTRA_CONCURRENCY = 4 # small extra fudge factor for connection limits
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,3 +1,4 @@
import copy
from datetime import timedelta
from typing import Any
@@ -18,242 +19,184 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 4
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
# tasks that run in either self-hosted on cloud
beat_task_templates: list[dict] = []
beat_task_templates.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
beat_task_templates.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task: dict[str, Any] = {}
# constant options for cloud beat task generators
task_schedule: timedelta = task["schedule"]
cloud_task["schedule"] = task_schedule
cloud_task["options"] = {}
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
# settings dependent on the original task
cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}"
cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]
optional_fields = ["queue", "priority", "expires"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]
return cloud_task
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
# by the DynamicTenantScheduler as system wide task and not a per tenant task
beat_system_tasks: list[dict] = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"schedule": timedelta(hours=1),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
# remaining tasks are cloud generators for per tenant tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.LOW,
},
},
]
if LLM_MODEL_UPDATE_API_URL:
cloud_tasks_to_schedule.append(
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
), # Check every hour
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"priority": OnyxCeleryPriority.LOW,
},
}
)
# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule.extend(
[
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
{
"name": "check-for-llm-model-update",
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
"schedule": timedelta(hours=1), # Check every hour
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
tasks_to_schedule = beat_task_templates
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule
def generate_cloud_tasks(
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
) -> list[dict[str, Any]]:
"""
beat_tasks: system wide tasks that can be sent as is
beat_templates: task templates that will be transformed into per tenant tasks via
the cloud_beat_task_generator
beat_multiplier: a multiplier that can be applied on top of the task schedule
to speed up or slow down the task generation rate. useful in production.
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
from incoming templates.
"""
if beat_multiplier <= 0:
raise ValueError("beat_multiplier must be positive!")
# start with the incoming beat tasks
cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)
# generate our cloud tasks from the templates
for beat_template in beat_templates:
cloud_task = make_cloud_generator_task(beat_template)
cloud_tasks.append(cloud_task)
# factor in the cloud multiplier
for cloud_task in cloud_tasks:
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
return cloud_tasks
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -186,7 +186,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
sync_type=SyncType.CONNECTOR_DELETION,
)
except Exception:
pass
task_logger.exception("insert_sync_record exceptioned.")
except TaskDependencyError:
redis_connector.delete.set_fence(None)

View File

@@ -228,12 +228,15 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_PERMISSIONS,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# set a basic fence to start
redis_connector.permissions.set_active()
@@ -257,11 +260,10 @@ def try_creating_permissions_sync_task(
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
payload_id = payload.celery_task_id
payload_id = payload.id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -290,6 +292,8 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
payload_id: str | None = None
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
@@ -332,9 +336,12 @@ def connector_permission_sync_generator_task(
sleep(1)
continue
payload_id = payload.id
logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key}"
f"fence={redis_connector.permissions.fence_key} "
f"payload_id={payload.id}"
)
break
@@ -413,7 +420,9 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
@@ -423,6 +432,10 @@ def connector_permission_sync_generator_task(
if lock.owned():
lock.release()
task_logger.info(
f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)
@shared_task(
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
@@ -446,14 +459,15 @@ def update_external_document_permissions_task(
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
continue_on_error=True,
)
# Then we upsert the document's external permissions in postgres
# Then upsert the document's external permissions
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
@@ -477,11 +491,11 @@ def update_external_document_permissions_task(
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
except Exception:
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} "
f"doc_id={doc_id}"
f"connector_id={connector_id} doc_id={doc_id}"
)
return False
@@ -659,7 +673,7 @@ def validate_permission_sync_fence(
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're only active if tasks_scanned > 0 and tasks_not_in_celery == 0
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
@@ -680,7 +694,8 @@ def validate_permission_sync_fence(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
f"fence={fence_key} "
f"payload_id={payload.id}"
)
redis_connector.permissions.reset()

View File

@@ -2,15 +2,17 @@ import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -32,7 +34,9 @@ from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
@@ -49,7 +53,8 @@ from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -107,11 +112,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -149,30 +154,32 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
payload_id = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
if not payload_id:
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
task_logger.info(
f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
# we want to run this less frequently than the overall task
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# clear fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_external_group_sync_fences(
tenant_id, self.app, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating external group sync fences"
)
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -191,9 +198,11 @@ def try_creating_external_group_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
@@ -215,11 +224,28 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# Signal active before creating fence
redis_connector.external_group_sync.set_active()
payload = RedisConnectorExternalGroupSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.external_group_sync.set_fence(payload)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
@@ -234,17 +260,10 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.EXTERNAL_GROUP,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -254,7 +273,7 @@ def try_creating_external_group_sync_task(
if lock.owned():
lock.release()
return 1
return payload_id
@shared_task(
@@ -312,7 +331,8 @@ def connector_external_group_sync_generator_task(
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key}"
f"fence={redis_connector.external_group_sync.fence_key} "
f"payload_id={payload.id}"
)
break
@@ -381,7 +401,7 @@ def connector_external_group_sync_generator_task(
)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
with get_session_with_tenant(tenant_id) as db_session:
@@ -401,32 +421,41 @@ def connector_external_group_sync_generator_task(
if lock.owned():
lock.release()
task_logger.info(
f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
)
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_sync_tasks = celery_get_unacked_task_ids(
reserved_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
# validate all existing external group sync tasks
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
continue
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_tasks,
r_celery,
)
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
@@ -435,7 +464,6 @@ def validate_external_group_sync_fence(
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
@@ -478,26 +506,26 @@ def validate_external_group_sync_fence(
if not redis_connector.external_group_sync.fenced:
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
try:
payload = redis_connector.external_group_sync.payload
except ValidationError:
task_logger.exception(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
@@ -527,7 +555,8 @@ def validate_external_group_sync_fence(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
f"fence={fence_key} "
f"payload_id={payload.id}"
)
redis_connector.external_group_sync.reset()

View File

@@ -423,8 +423,8 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
redis_connector,
redis_connector_index,
lock,
r,
)

View File

@@ -99,16 +99,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
@@ -120,7 +120,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
if self.redis_connector.stop.fenced:
return True
return False
@@ -143,6 +143,8 @@ class IndexingCallback(IndexingHeartbeatInterface):
# self.last_parent_check = now
try:
self.redis_connector.prune.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -165,7 +167,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
def validate_indexing_fence(

View File

@@ -420,6 +420,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
- Throughput (docs/min) (only if success)
- Raw start/end times for each sync
"""
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records that ended in the last hour
@@ -587,6 +588,10 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
else:
# Only user groups and document set sync records have
# an associated entity we can use for latency metrics
continue
if entity is None:
task_logger.error(
@@ -777,7 +782,7 @@ def cloud_check_alembic() -> bool | None:
tenant_to_revision[tenant_id] = result_scalar
except Exception:
task_logger.warning(f"Tenant {tenant_id} has no revision!")
task_logger.error(f"Tenant {tenant_id} has no revision!")
tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION
# get the total count of each revision

View File

@@ -1,28 +1,39 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
@@ -35,10 +46,15 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -93,6 +109,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -123,13 +141,28 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
if not _is_pruning_due(cc_pair):
continue
tasks_created = try_creating_prune_generator_task(
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not tasks_created:
if not payload_id:
continue
task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
except Exception:
task_logger.exception("Exception while validating pruning fences")
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -149,7 +182,7 @@ def try_creating_prune_generator_task(
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -168,7 +201,7 @@ def try_creating_prune_generator_task(
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
)
@@ -200,7 +233,30 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# signal active before the fence is set
redis_connector.prune.set_active()
# set a basic fence to start
payload = RedisConnectorPrunePayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.prune.set_fence(payload)
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
@@ -213,16 +269,11 @@ def try_creating_prune_generator_task(
priority=OnyxCeleryPriority.LOW,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair.id,
sync_type=SyncType.PRUNING,
)
# fill in the celery task id
payload.celery_task_id = result.id
redis_connector.prune.set_fence(payload)
# set this only after all tasks have been added
redis_connector.prune.set_fence(True)
payload_id = payload.id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
@@ -230,7 +281,7 @@ def try_creating_prune_generator_task(
if lock.owned():
lock.release()
return 1
return payload_id
@shared_task(
@@ -252,6 +303,8 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
payload_id: str | None = None
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
@@ -265,6 +318,46 @@ def connector_pruning_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_prune_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.prune.fence_key}"
)
if not redis_connector.prune.fenced: # The fence must exist
raise ValueError(
f"connector_prune_generator_task - fence not found: "
f"fence={redis_connector.prune.fence_key}"
)
payload = redis_connector.prune.payload # The payload must exist
if not payload:
raise ValueError(
"connector_prune_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_prune_generator_task - Waiting for fence: "
f"fence={redis_connector.prune.fence_key}"
)
time.sleep(1)
continue
payload_id = payload.id
logger.info(
f"connector_prune_generator_task - Fence found, continuing...: "
f"fence={redis_connector.prune.fence_key} "
f"payload_id={payload.id}"
)
break
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
@@ -294,6 +387,18 @@ def connector_pruning_generator_task(
)
return
payload = redis_connector.prune.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPrunePayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.prune.set_fence(new_payload)
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
@@ -307,10 +412,13 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector_index = redis_connector.new_index(search_settings.id)
callback = IndexingCallback(
0,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
redis_connector,
redis_connector_index,
lock,
r,
)
@@ -357,7 +465,9 @@ def connector_pruning_generator_task(
redis_connector.prune.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
f"Pruning exceptioned: cc_pair={cc_pair_id} "
f"connector={connector_id} "
f"payload_id={payload_id}"
)
redis_connector.prune.reset()
@@ -366,7 +476,9 @@ def connector_pruning_generator_task(
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
task_logger.info(
f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
"""Monitoring pruning utils, called in monitor_vespa_sync"""
@@ -415,4 +527,184 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.taskset_clear()
redis_connector.prune.generator_clear()
redis_connector.prune.set_fence(False)
redis_connector.prune.set_fence(None)
def validate_pruning_fences(
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
# the queue for a single pruning generator task
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
# the queue for a reasonably large set of lightweight deletion tasks
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
continue
validate_pruning_fence(
tenant_id,
key_bytes,
reserved_generator_tasks,
queued_upsert_tasks,
r,
r_celery,
)
lock_beat.reacquire()
return
def validate_pruning_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""See validate_indexing_fence for an overall idea of validation flows.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_pruning_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.prune.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.prune.payload
except ValidationError:
task_logger.exception(
"validate_pruning_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.prune.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_PRUNING,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.prune.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.prune.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own pruning taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.prune.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_pruning_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.prune.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.prune.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_pruning_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
)
redis_connector.prune.reset()
return

View File

@@ -339,11 +339,15 @@ def try_generate_document_set_sync_tasks(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
try:
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
return tasks_generated
@@ -411,11 +415,15 @@ def try_generate_user_group_sync_tasks(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
try:
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
except Exception:
task_logger.exception("insert_sync_record exceptioned.")
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
@@ -904,7 +912,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
keys = cast(set[Any], r.smembers(OnyxRedisConstants.ACTIVE_FENCES))
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)

View File

@@ -140,6 +140,7 @@ def build_citations_user_message(
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str = "",
context_type: str = "context documents",
) -> HumanMessage:
multilingual_expansion = get_multilingual_expansion()
task_prompt_with_reminder = build_task_prompt_reminders(
@@ -156,6 +157,7 @@ def build_citations_user_message(
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
user_prompt = CITATIONS_PROMPT.format(
context_type=context_type,
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
@@ -165,6 +167,7 @@ def build_citations_user_message(
else:
# if no context docs provided, assume we're in the tool calling flow
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
context_type=context_type,
task_prompt=task_prompt_with_reminder,
user_query=query,
history_block=history_block,

View File

@@ -13,21 +13,6 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
#####
# Agent Configs
#####
@@ -92,76 +77,4 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
GRAPH_VERSION_NAME: str = "a"

View File

@@ -324,6 +324,7 @@ class OnyxRedisSignals:
BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
"signal:block_validate_permission_sync_fences"
)
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
@@ -345,6 +346,9 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
class OnyxCeleryTask:
DEFAULT = "celery"

View File

@@ -65,10 +65,25 @@ class AirtableConnector(LoadConnector):
base_id: str,
table_name_or_id: str,
treat_all_non_attachment_fields_as_metadata: bool = False,
view_id: str | None = None,
share_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
"""Initialize an AirtableConnector.
Args:
base_id: The ID of the Airtable base to connect to
table_name_or_id: The name or ID of the table to index
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
view_id: Optional ID of a specific view to use
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
batch_size: Number of records to process in each batch
"""
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.view_id = view_id
self.share_id = share_id
self.batch_size = batch_size
self._airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
@@ -85,6 +100,39 @@ class AirtableConnector(LoadConnector):
raise AirtableClientNotSetUpError()
return self._airtable_client
@classmethod
def _get_record_url(
cls,
base_id: str,
table_id: str,
record_id: str,
share_id: str | None,
view_id: str | None,
field_id: str | None = None,
attachment_id: str | None = None,
) -> str:
"""Constructs the URL for a record, optionally including field and attachment IDs
Full possible structure is:
https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID
"""
# If we have a shared link, use that view for better UX
if share_id:
base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}"
else:
base_url = f"https://airtable.com/{base_id}/{table_id}"
if view_id:
base_url = f"{base_url}/{view_id}"
base_url = f"{base_url}/{record_id}"
if field_id and attachment_id:
return f"{base_url}/{field_id}/{attachment_id}?blocks=hide"
return base_url
def _extract_field_values(
self,
field_id: str,
@@ -110,8 +158,10 @@ class AirtableConnector(LoadConnector):
if field_type == "multipleRecordLinks":
return []
# default link to use for non-attachment fields
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
# Get the base URL for this record
default_link = self._get_record_url(
base_id, table_id, record_id, self.share_id, self.view_id or view_id
)
if field_type == "multipleAttachments":
attachment_texts: list[tuple[str, str]] = []
@@ -165,17 +215,16 @@ class AirtableConnector(LoadConnector):
extension=file_ext,
)
if attachment_text:
# slightly nicer loading experience if we can specify the view ID
if view_id:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
else:
attachment_link = (
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
f"/{field_id}/{attachment_id}?blocks=hide"
)
# Use the helper method to construct attachment URLs
attachment_link = self._get_record_url(
base_id,
table_id,
record_id,
self.share_id,
self.view_id or view_id,
field_id,
attachment_id,
)
attachment_texts.append(
(f"{filename}:\n{attachment_text}", attachment_link)
)

View File

@@ -27,6 +27,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -319,6 +320,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
@@ -386,4 +388,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list

View File

@@ -30,6 +30,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -321,6 +322,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self,
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
@@ -343,6 +345,15 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
if doc_batch:
yield doc_batch
@@ -368,9 +379,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._fetch_slim_threads(start, end)
yield from self._fetch_slim_threads(start, end, callback=callback)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@@ -42,6 +42,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -564,6 +565,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_batch = []
for file in self._fetch_drive_items(
@@ -576,15 +578,26 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if len(slim_batch) >= SLIM_BATCH_SIZE:
yield slim_batch
slim_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
try:
yield from self._extract_slim_docs_from_google_drive(start, end)
yield from self._extract_slim_docs_from_google_drive(
start, end, callback=callback
)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
SecondsSinceUnixEpoch = float
@@ -63,6 +64,7 @@ class SlimConnector(BaseConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError

View File

@@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -245,6 +246,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"

View File

@@ -21,6 +21,7 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -176,6 +177,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:

View File

@@ -21,6 +21,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -242,6 +243,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):

View File

@@ -27,6 +27,7 @@ from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -98,6 +99,7 @@ def get_channel_messages(
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel"""
# join so that the bot can access messages
@@ -115,6 +117,11 @@ def get_channel_messages(
oldest=oldest,
latest=latest,
):
if callback:
if callback.should_stop():
raise RuntimeError("get_channel_messages: Stop signal detected")
callback.progress("get_channel_messages", 0)
yield cast(list[MessageType], result["messages"])
@@ -325,6 +332,7 @@ def _get_all_doc_ids(
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Get all document ids in the workspace, channel by channel
@@ -342,6 +350,7 @@ def _get_all_doc_ids(
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
callback=callback,
)
message_ts_set: set[str] = set()
@@ -390,6 +399,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
@@ -398,6 +408,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
callback=callback,
)
def poll_source(

View File

@@ -39,19 +39,6 @@ def get_message_link(
return permalink
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
def logged_call(**kwargs: Any) -> SlackResponse:
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
result = call(**kwargs)
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
return result
return logged_call
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
@@ -127,18 +114,14 @@ def make_slack_api_rate_limited(
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
basic_retry_wrapper(make_slack_api_rate_limited(call))
)(**kwargs)

View File

@@ -20,6 +20,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -405,6 +406,7 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
if self.content_type == "articles":

View File

@@ -152,7 +152,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)

View File

@@ -204,6 +204,14 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
user=user,
@@ -228,6 +236,7 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -509,6 +518,7 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -589,6 +599,23 @@ def delete_old_default_personas(
db_session.commit()
def update_persona_is_default(
persona_id: int,
is_default: bool,
db_session: Session,
user: User | None = None,
) -> None:
persona = fetch_persona_by_id_for_user(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if not persona.is_public:
persona.is_public = True
persona.is_default_persona = is_default
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_visible: bool,

View File

@@ -6,6 +6,7 @@ from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
@@ -274,7 +275,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str]
db_session: Session, emails: list[str], continue_on_error: bool = False
) -> list[User]:
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
@@ -283,10 +284,23 @@ def batch_add_ext_perm_user_if_not_exists(
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))
db_session.add_all(new_users)
db_session.commit()
return found_users + new_users
try:
db_session.add_all(new_users)
db_session.commit()
except IntegrityError:
db_session.rollback()
if not continue_on_error:
raise
for user in new_users:
try:
db_session.add(user)
db_session.commit()
except IntegrityError:
db_session.rollback()
continue
# Fetch all users again to ensure we have the most up-to-date list
all_users, _ = _get_users_by_emails(db_session, lower_emails)
return all_users
def delete_user_from_db(

View File

@@ -17,6 +17,7 @@ from uuid import UUID
import httpx # type: ignore
import requests # type: ignore
from retry import retry
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import NUM_RETURNED_HITS
@@ -549,6 +550,11 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
@retry(
tries=3,
delay=1,
backoff=2,
)
def _update_single_chunk(
self,
doc_chunk_id: UUID,
@@ -559,6 +565,7 @@ class VespaIndex(DocumentIndex):
) -> None:
"""
Update a single "chunk" (document) in Vespa using its chunk ID.
Retries if we encounter transient HTTPStatusError (e.g., overload).
"""
update_dict: dict[str, dict] = {"fields": {}}
@@ -567,13 +574,11 @@ class VespaIndex(DocumentIndex):
update_dict["fields"][BOOST] = {"assign": fields.boost}
if fields.document_sets is not None:
# WeightedSet<string> needs a map { item: weight, ... }
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {document_set: 1 for document_set in fields.document_sets}
}
if fields.access is not None:
# Similar to above
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
}
@@ -585,7 +590,10 @@ class VespaIndex(DocumentIndex):
logger.error("Update request received but nothing to update.")
return
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
"?create=true"
)
try:
resp = http_client.put(
@@ -595,8 +603,11 @@ class VespaIndex(DocumentIndex):
)
resp.raise_for_status()
except httpx.HTTPStatusError as e:
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
logger.error(error_message)
logger.error(
f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). "
f"Details: {e.response.text}"
)
# Re-raise so the @retry decorator will catch and retry
raise
def update_single(

View File

@@ -146,6 +146,23 @@ def _index_vespa_chunk(
title = document.get_title_for_document_index()
metadata_json = document.metadata
cleaned_metadata_json: dict[str, str | list[str]] = {}
for key, value in metadata_json.items():
cleaned_key = remove_invalid_unicode_chars(key)
if isinstance(value, list):
cleaned_metadata_json[cleaned_key] = [
remove_invalid_unicode_chars(item) for item in value
]
else:
cleaned_metadata_json[cleaned_key] = remove_invalid_unicode_chars(value)
metadata_list = document.get_metadata_str_attributes()
if metadata_list:
metadata_list = [
remove_invalid_unicode_chars(metadata) for metadata in metadata_list
]
vespa_document_fields = {
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
@@ -166,10 +183,10 @@ def _index_vespa_chunk(
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
SECTION_CONTINUATION: chunk.section_continuation,
LARGE_CHUNK_REFERENCE_IDS: chunk.large_chunk_reference_ids,
METADATA: json.dumps(document.metadata),
METADATA: json.dumps(cleaned_metadata_json),
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

View File

@@ -27,6 +27,7 @@ from langchain_core.prompt_values import PromptValue
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
@@ -35,6 +36,7 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -50,18 +52,6 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
class LLMTimeoutError(Exception):
"""
Exception raised when an LLM call times out.
"""
class LLMRateLimitError(Exception):
"""
Exception raised when an LLM call is rate limited.
"""
def _base_msg_to_role(msg: BaseMessage) -> str:
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
return "user"
@@ -241,15 +231,15 @@ class DefaultMultiLLM(LLM):
def __init__(
self,
api_key: str | None,
timeout: int,
model_provider: str,
model_name: str,
timeout: int | None = None,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
max_output_tokens: int | None = None,
custom_llm_provider: str | None = None,
temperature: float = GEN_AI_TEMPERATURE,
temperature: float | None = None,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
@@ -257,9 +247,16 @@ class DefaultMultiLLM(LLM):
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
if timeout is None:
if model_is_reasoning_model(model_name):
self._timeout = QA_TIMEOUT * 10 # Reasoning models are slow
else:
self._timeout = QA_TIMEOUT
self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature
self._model_provider = model_provider
self._model_version = model_name
self._temperature = temperature
self._api_key = api_key
self._deployment_name = deployment_name
self._api_base = api_base
@@ -392,7 +389,6 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -418,7 +414,7 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -437,12 +433,6 @@ class DefaultMultiLLM(LLM):
except Exception as e:
self._record_error(processed_prompt, e)
# for break pointing
if isinstance(e, litellm.Timeout):
raise LLMTimeoutError(e)
elif isinstance(e, litellm.RateLimitError):
raise LLMRateLimitError(e)
raise e
@property
@@ -463,7 +453,6 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -471,12 +460,7 @@ class DefaultMultiLLM(LLM):
response = cast(
litellm.ModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
@@ -494,31 +478,19 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
prompt, tools, tool_choice, True, structured_response_format
),
)
try:

View File

@@ -81,7 +81,6 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -91,6 +90,5 @@ class CustomModelServer(LLM):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -2,7 +2,6 @@ from typing import Any
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
@@ -88,8 +87,8 @@ def get_llms_for_persona(
def get_default_llms(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
@@ -138,7 +137,7 @@ def get_llm(
api_version: str | None = None,
custom_config: dict[str, str] | None = None,
temperature: float | None = None,
timeout: int = QA_TIMEOUT,
timeout: int | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:

View File

@@ -90,13 +90,12 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
@@ -106,7 +105,6 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -116,13 +114,12 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt, tools, tool_choice, structured_response_format
)
tokens = []
@@ -141,6 +138,5 @@ class LLM(abc.ABC):
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError

View File

@@ -29,11 +29,11 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o3-mini",
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"o1",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",
"o1-preview",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",

View File

@@ -543,3 +543,14 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
def model_is_reasoning_model(model_name: str) -> bool:
_REASONING_MODEL_NAMES = [
"o1",
"o1-mini",
"o3-mini",
"deepseek-reasoner",
"deepseek-r1",
]
return model_name.lower() in _REASONING_MODEL_NAMES

View File

@@ -5,6 +5,8 @@ UNKNOWN_ANSWER = "I do not have enough information to answer this question."
NO_RECOVERED_DOCS = "No relevant information recovered"
YES = "yes"
NO = "no"
# Framing/Support/Template Prompts
HISTORY_FRAMING_PROMPT = f"""
For more context, here is the history of the conversation so far that preceded this question:

View File

@@ -91,7 +91,7 @@ SAMPLE RESPONSE:
# similar to the chat flow, but with the option of including a
# "conversation history" block
CITATIONS_PROMPT = f"""
Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT}
Refer to the following {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT}
CONTEXT:
{GENERAL_SEP_PAT}
@@ -108,7 +108,7 @@ CONTEXT:
# NOTE: need to add the extra line about "getting right to the point" since the
# tool calling models from OpenAI tend to be more verbose
CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
Refer to the provided {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.
{{history_block}}{{task_prompt}}

View File

@@ -80,7 +80,8 @@ class RedisConnectorPermissionSync:
def get_active_task_count(self) -> int:
"""Count of active permission sync tasks"""
count = 0
for _ in self.redis.scan_iter(
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from typing import cast
import redis
@@ -8,10 +7,12 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -37,6 +38,12 @@ class RedisConnectorExternalGroupSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -50,6 +57,7 @@ class RedisConnectorExternalGroupSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -66,7 +74,8 @@ class RedisConnectorExternalGroupSync:
def get_active_task_count(self) -> int:
"""Count of active external group syncing tasks"""
count = 0
for _ in self.redis.scan_iter(
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
@@ -83,10 +92,11 @@ class RedisConnectorExternalGroupSync:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
fence_raw = self.redis.get(self.fence_key)
if fence_raw is None:
return None
fence_bytes = cast(bytes, fence_raw)
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
cast(str, fence_str)
@@ -99,10 +109,26 @@ class RedisConnectorExternalGroupSync:
payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
if not payload:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
@@ -138,6 +164,8 @@ class RedisConnectorExternalGroupSync:
pass
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -152,6 +180,9 @@ class RedisConnectorExternalGroupSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorExternalGroupSync.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -1,9 +1,11 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -16,6 +18,13 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPrunePayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
class RedisConnectorPrune:
"""Manages interactions with redis for pruning tasks. Should only be accessed
through RedisConnector."""
@@ -36,6 +45,12 @@ class RedisConnectorPrune:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -49,6 +64,7 @@ class RedisConnectorPrune:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -65,8 +81,10 @@ class RedisConnectorPrune:
def get_active_task_count(self) -> int:
"""Count of active pruning tasks"""
count = 0
for key in self.redis.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
for _ in self.redis.sscan_iter(
OnyxRedisConstants.ACTIVE_FENCES,
RedisConnectorPrune.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
count += 1
return count
@@ -78,15 +96,44 @@ class RedisConnectorPrune:
return False
def set_fence(self, value: bool) -> None:
if not value:
@property
def payload(self) -> RedisConnectorPrunePayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPrunePayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(
self,
payload: RedisConnectorPrunePayload | None,
) -> None:
if not payload:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@@ -162,6 +209,7 @@ class RedisConnectorPrune:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -176,6 +224,9 @@ class RedisConnectorPrune:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -368,15 +368,17 @@ def prune_cc_pair(
f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(
payload_id = try_creating_prune_generator_task(
primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Pruning task creation failed.",
)
logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the pruning task.",
@@ -514,15 +516,17 @@ def sync_cc_pair_groups(
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_external_group_sync_task(
payload_id = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="External group sync task creation failed.",
)
logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the external group sync task.",

View File

@@ -32,6 +32,7 @@ from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
@@ -56,7 +57,6 @@ from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
logger = setup_logger()
@@ -72,6 +72,10 @@ class IsPublicRequest(BaseModel):
is_public: bool
class IsDefaultRequest(BaseModel):
is_default_persona: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
@@ -106,6 +110,25 @@ def patch_user_presona_public_status(
raise HTTPException(status_code=403, detail=str(e))
@admin_router.patch("/{persona_id}/default")
def patch_persona_default_status(
persona_id: int,
is_default_request: IsDefaultRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_is_default(
persona_id=persona_id,
is_default=is_default_request.is_default_persona,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona default status")
raise HTTPException(status_code=403, detail=str(e))
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,

View File

@@ -279,4 +279,5 @@ class InternetSearchTool(Tool):
using_tool_calling_llm=using_tool_calling_llm,
answer_style_config=self.answer_style_config,
prompt_config=self.prompt_config,
context_type="internet search results",
)

View File

@@ -25,6 +25,7 @@ def build_next_prompt_for_search_like_tool(
using_tool_calling_llm: bool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
context_type: str = "context documents",
) -> AnswerPromptBuilder:
if not using_tool_calling_llm:
final_context_docs_response = next(
@@ -58,6 +59,7 @@ def build_next_prompt_for_search_like_tool(
else False
),
history_message=prompt_builder.single_message_history or "",
context_type=context_type,
)
)

View File

@@ -86,7 +86,10 @@ def run_functions_in_parallel(
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
are the result_id of the FunctionCall and the values are the results of the call.
"""
results = {}
results: dict[str, Any] = {}
if len(function_calls) == 0:
return results
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {

View File

@@ -9,6 +9,8 @@ from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
BASE_VIEW_ID = "viwVUEJjWPd8XYjh8"
class AirtableConfig(BaseModel):
base_id: str
@@ -46,6 +48,8 @@ def create_test_document(
days_since_status_change: int | None,
attachments: list[tuple[str, str]] | None = None,
all_fields_as_metadata: bool = False,
share_id: str | None = None,
view_id: str | None = None,
) -> Document:
base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
@@ -60,7 +64,13 @@ def create_test_document(
f"Required environment variables not set: {', '.join(missing_vars)}. "
"These variables are required to run Airtable connector tests."
)
link_base = f"https://airtable.com/{base_id}/{table_id}"
link_base = f"https://airtable.com/{base_id}"
if share_id:
link_base = f"{link_base}/{share_id}"
link_base = f"{link_base}/{table_id}"
if view_id:
link_base = f"{link_base}/{view_id}"
sections = []
if not all_fields_as_metadata:
@@ -214,6 +224,7 @@ def test_airtable_connector_basic(
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
all_fields_as_metadata=False,
view_id=BASE_VIEW_ID,
),
create_test_document(
id="reccSlIA4pZEFxPBg",
@@ -234,6 +245,7 @@ def test_airtable_connector_basic(
)
],
all_fields_as_metadata=False,
view_id=BASE_VIEW_ID,
),
]
@@ -285,6 +297,81 @@ def test_airtable_connector_all_metadata(
)
],
all_fields_as_metadata=True,
view_id=BASE_VIEW_ID,
),
]
# Compare documents using the utility function
compare_documents(doc_batch, expected_docs)
def test_airtable_connector_with_share_and_view(
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
"""Test behavior when using share_id and view_id for URL generation."""
SHARE_ID = "shrkfjEzDmLaDtK83"
connector = AirtableConnector(
base_id=airtable_config.base_id,
table_name_or_id=airtable_config.table_identifier,
treat_all_non_attachment_fields_as_metadata=False,
share_id=SHARE_ID,
view_id=BASE_VIEW_ID,
)
connector.load_credentials(
{
"airtable_access_token": airtable_config.access_token,
}
)
doc_batch_generator = connector.load_from_state()
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 2
expected_docs = [
create_test_document(
id="rec8BnxDLyWeegOuO",
title="Slow Internet",
description="The internet connection is very slow.",
priority="Medium",
status="In Progress",
ticket_id="2",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
all_fields_as_metadata=False,
share_id=SHARE_ID,
view_id=BASE_VIEW_ID,
),
create_test_document(
id="reccSlIA4pZEFxPBg",
title="Printer Issue",
description="The office printer is not working.",
priority="High",
status="Open",
ticket_id="1",
created_time="2024-12-24T21:02:49.000Z",
status_last_changed="2024-12-24T21:02:49.000Z",
days_since_status_change=0,
assignee="Chris Weaver (chris@onyx.app)",
submitted_by="Chris Weaver (chris@onyx.app)",
attachments=[
(
"Test.pdf:\ntesting!!!",
(
f"https://airtable.com/{airtable_config.base_id}/{SHARE_ID}/"
f"{os.environ['AIRTABLE_TEST_TABLE_ID']}/{BASE_VIEW_ID}/reccSlIA4pZEFxPBg/"
"fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide"
),
)
],
all_fields_as_metadata=False,
share_id=SHARE_ID,
view_id=BASE_VIEW_ID,
),
]

View File

@@ -66,7 +66,7 @@ class PersonaManager:
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request.model_dump(),
json=persona_creation_request.model_dump(mode="json"),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -119,6 +119,7 @@ class PersonaManager:
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
persona_update_request = PersonaUpsertRequest(
name=name or persona.name,
description=description or persona.description,
@@ -146,7 +147,7 @@ class PersonaManager:
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request.model_dump(),
json=persona_update_request.model_dump(mode="json"),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,

View File

@@ -58,6 +58,7 @@ def test_persona_permissions(reset: None) -> None:
description="A persona created by basic user",
is_public=False,
groups=[],
users=[admin_user.id],
user_performing_action=basic_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
@@ -139,9 +140,14 @@ def test_persona_permissions(reset: None) -> None:
"""Test admin permissions"""
# Admin can edit any persona
# the persona was shared with the admin user on creation
# this edit call will simulate having the same user in the list twice.
# The server side should dedupe and handle this correctly (prior bug)
PersonaManager.edit(
persona=basic_user_persona,
description="Updated by admin",
description="Updated by admin 2",
users=[admin_user.id, admin_user.id],
user_performing_action=admin_user,
)
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)

View File

@@ -23,12 +23,12 @@ _Note:_ if you are having problems accessing the ^, try setting the `WEB_DOMAIN`
`http://127.0.0.1:3000` and accessing it there.
## Testing
This testing process will reset your application into a clean state.
This testing process will reset your application into a clean state.
Don't run these tests if you don't want to do this!
Bring up the entire application.
1. Reset the instance
```cd backend
@@ -59,4 +59,4 @@ may use this for local troubleshooting and testing.
```
cd web
npx chromatic --playwright --project-token={your token here}
```
```

View File

@@ -4,7 +4,7 @@
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.js",
"config": "tailwind-themes/tailwind.config.js",
"css": "src/app/globals.css",
"baseColor": "neutral",
"cssVariables": false,

1001
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@
"version-comment": "version field must be SemVer or chromatic will barf",
"private": true,
"scripts": {
"dev": "next dev --turbopack",
"dev": "next dev --turbo",
"build": "next build",
"start": "next start",
"lint": "next lint",
@@ -21,17 +21,17 @@
"@radix-ui/react-accordion": "^1.2.2",
"@radix-ui/react-checkbox": "^1.1.2",
"@radix-ui/react-collapsible": "^1.1.2",
"@radix-ui/react-dialog": "^1.1.2",
"@radix-ui/react-dropdown-menu": "^2.1.4",
"@radix-ui/react-dialog": "^1.1.6",
"@radix-ui/react-dropdown-menu": "^2.1.6",
"@radix-ui/react-label": "^2.1.1",
"@radix-ui/react-popover": "^1.1.2",
"@radix-ui/react-popover": "^1.1.6",
"@radix-ui/react-radio-group": "^1.2.2",
"@radix-ui/react-scroll-area": "^1.2.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-select": "^2.1.6",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slider": "^1.2.2",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-slot": "^1.1.2",
"@radix-ui/react-switch": "^1.1.3",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^8.50.0",
@@ -56,6 +56,7 @@
"lucide-react": "^0.454.0",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^15.0.2",
"next-themes": "^0.4.4",
"npm": "^10.8.0",
"postcss": "^8.4.31",
"posthog-js": "^1.176.0",

Binary file not shown.

Before

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.9 KiB

BIN
web/public/discord.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.0 KiB

BIN
web/public/litellm.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 89 KiB

BIN
web/public/logo-dark.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

View File

@@ -27,8 +27,12 @@ function SourceTile({
w-40
cursor-pointer
shadow-md
hover:bg-hover
${preSelect ? "bg-hover subtle-pulse" : "bg-hover-light"}
hover:bg-accent-background-hovered
${
preSelect
? "bg-accent-background-hovered subtle-pulse"
: "bg-accent-background"
}
`}
href={sourceMetadata.adminUrl}
>

View File

@@ -56,7 +56,7 @@ function NewApiKeyModal({
<div className="flex mt-2">
<b className="my-auto break-all">{apiKey}</b>
<div
className="ml-2 my-auto p-2 hover:bg-hover rounded cursor-pointer"
className="ml-2 my-auto p-2 hover:bg-accent-background-hovered rounded cursor-pointer"
onClick={() => {
setCopyClicked(true);
navigator.clipboard.writeText(apiKey);
@@ -112,7 +112,10 @@ function Main() {
}
const newApiKeyButton = (
<CreateButton href="/admin/api-key/new" text="Create API Key" />
<CreateButton
onClick={() => setShowCreateUpdateForm(true)}
text="Create API Key"
/>
);
if (apiKeys.length === 0) {
@@ -179,7 +182,7 @@ function Main() {
flex
mb-1
w-fit
hover:bg-hover cursor-pointer
hover:bg-accent-background-hovered cursor-pointer
p-2
rounded-lg
border-border
@@ -203,7 +206,7 @@ function Main() {
flex
mb-1
w-fit
hover:bg-hover cursor-pointer
hover:bg-accent-background-hovered cursor-pointer
p-2
rounded-lg
border-border

View File

@@ -3,7 +3,13 @@
import React from "react";
import { Option } from "@/components/Dropdown";
import { generateRandomIconShape } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
import {
CCPairBasicInfo,
DocumentSet,
User,
UserGroup,
UserRole,
} from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik";
@@ -33,9 +39,8 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import Link from "next/link";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import { FiInfo } from "react-icons/fi";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
@@ -71,11 +76,11 @@ import {
Option as DropdownOption,
} from "@/components/Dropdown";
import { SourceChip } from "@/app/chat/input/ChatInputBar";
import { TagIcon, UserIcon, XIcon } from "lucide-react";
import { TagIcon, UserIcon, XIcon, InfoIcon } from "lucide-react";
import { LLMSelector } from "@/components/llm/LLMSelector";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import Title from "@/components/ui/title";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
@@ -127,6 +132,8 @@ export function AssistantEditor({
}) {
const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
const router = useRouter();
const searchParams = useSearchParams();
const isAdminPage = searchParams.get("admin") === "true";
const { popup, setPopup } = usePopup();
const { labels, refreshLabels, createLabel, updateLabel, deleteLabel } =
@@ -216,6 +223,8 @@ export function AssistantEditor({
enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id);
});
const [showVisibilityWarning, setShowVisibilityWarning] = useState(false);
const initialValues = {
name: existingPersona?.name ?? "",
description: existingPersona?.description ?? "",
@@ -252,6 +261,7 @@ export function AssistantEditor({
(u) => u.id !== existingPersona.owner?.id
) ?? [],
selectedGroups: existingPersona?.groups ?? [],
is_default_persona: existingPersona?.is_default_persona ?? false,
};
interface AssistantPrompt {
@@ -308,24 +318,12 @@ export function AssistantEditor({
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
const { data: userGroups } = useUserGroups();
// const { data: allUsers } = useUsers({ includeApiKeys: false }) as {
// data: MinimalUserSnapshot[] | undefined;
// };
const { data: users } = useSWR<MinimalUserSnapshot[]>(
"/api/users",
errorHandlingFetcher
);
const mapUsersToMinimalSnapshot = (users: any): MinimalUserSnapshot[] => {
if (!users || !Array.isArray(users.users)) return [];
return users.users.map((user: any) => ({
id: user.id,
name: user.name,
email: user.email,
}));
};
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
if (!labels) {
@@ -346,9 +344,7 @@ export function AssistantEditor({
if (response.ok) {
await refreshAssistants();
router.push(
redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
? `/admin/assistants?u=${Date.now()}`
: `/chat`
isAdminPage ? `/admin/assistants?u=${Date.now()}` : `/chat`
);
} else {
setPopup({
@@ -374,8 +370,9 @@ export function AssistantEditor({
<BackButton />
</div>
)}
{labelToDelete && (
<DeleteEntityModal
<ConfirmEntityModal
entityType="label"
entityName={labelToDelete.name}
onClose={() => setLabelToDelete(null)}
@@ -398,7 +395,7 @@ export function AssistantEditor({
/>
)}
{deleteModalOpen && existingPersona && (
<DeleteEntityModal
<ConfirmEntityModal
entityType="Persona"
entityName={existingPersona.name}
onClose={closeDeleteModal}
@@ -439,6 +436,7 @@ export function AssistantEditor({
label_ids: Yup.array().of(Yup.number()),
selectedUsers: Yup.array().of(Yup.object()),
selectedGroups: Yup.array().of(Yup.number()),
is_default_persona: Yup.boolean().required(),
})
.test(
"system-prompt-or-task-prompt",
@@ -459,6 +457,19 @@ export function AssistantEditor({
"Must provide either Instructions or Reminders (Advanced)",
});
}
)
.test(
"default-persona-public",
"Default persona must be public",
function (values) {
if (values.is_default_persona && !values.is_public) {
return this.createError({
path: "is_public",
message: "Default persona must be public",
});
}
return true;
}
)}
onSubmit={async (values, formikHelpers) => {
if (
@@ -499,7 +510,6 @@ export function AssistantEditor({
const submissionData: PersonaUpsertParameters = {
...values,
existing_prompt_id: existingPrompt?.id ?? null,
is_default_persona: admin!,
starter_messages: starterMessages,
groups: groups,
users: values.is_public
@@ -563,8 +573,9 @@ export function AssistantEditor({
}
await refreshAssistants();
router.push(
redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
isAdminPage
? `/admin/assistants?u=${Date.now()}`
: `/chat?assistantId=${assistantId}`
);
@@ -825,10 +836,7 @@ export function AssistantEditor({
</TooltipProvider>
</div>
</div>
<p
className="text-sm text-subtle"
style={{ color: "rgb(113, 114, 121)" }}
>
<p className="text-sm text-neutral-700 dark:text-neutral-400">
Attach additional unique knowledge to this assistant
</p>
</div>
@@ -1008,6 +1016,22 @@ export function AssistantEditor({
{showAdvancedOptions && (
<>
<div className="max-w-4xl w-full">
{user?.role == UserRole.ADMIN && (
<BooleanFormField
onChange={(checked) => {
if (checked) {
setFieldValue("is_public", true);
setFieldValue("is_default_persona", true);
}
}}
name="is_default_persona"
label="Featured Assistant"
subtext="If set, this assistant will be pinned for all new users and appear in the Featured list in the assistant explorer. This also makes the assistant public."
/>
)}
<Separator />
<div className="flex gap-x-2 items-center ">
<div className="block font-medium text-sm">Access</div>
</div>
@@ -1017,22 +1041,60 @@ export function AssistantEditor({
<div className="min-h-[100px]">
<div className="flex items-center mb-2">
<SwitchField
name="is_public"
size="md"
onCheckedChange={(checked) => {
setFieldValue("is_public", checked);
if (checked) {
setFieldValue("selectedUsers", []);
setFieldValue("selectedGroups", []);
}
}}
/>
<TooltipProvider delayDuration={0}>
<Tooltip>
<TooltipTrigger asChild>
<div>
<SwitchField
name="is_public"
size="md"
onCheckedChange={(checked) => {
if (values.is_default_persona && !checked) {
setShowVisibilityWarning(true);
} else {
setFieldValue("is_public", checked);
if (!checked) {
// Even though this code path should not be possible,
// we set the default persona to false to be safe
setFieldValue(
"is_default_persona",
false
);
}
if (checked) {
setFieldValue("selectedUsers", []);
setFieldValue("selectedGroups", []);
}
}
}}
disabled={values.is_default_persona}
/>
</div>
</TooltipTrigger>
{values.is_default_persona && (
<TooltipContent side="top" align="center">
Default persona must be public. Set
&quot;Default Persona&quot; to false to change
visibility.
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<span className="text-sm ml-2">
{values.is_public ? "Public" : "Private"}
</span>
</div>
{showVisibilityWarning && (
<div className="flex items-center text-warning mt-2">
<InfoIcon size={16} className="mr-2" />
<span className="text-sm">
Default persona must be public. Visibility has been
automatically set to public.
</span>
</div>
)}
{values.is_public ? (
<p className="text-sm text-text-dark">
Anyone from your organization can view and use this
@@ -1217,7 +1279,7 @@ export function AssistantEditor({
setFieldValue("label_ids", newLabelIds);
}}
itemComponent={({ option }) => (
<div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-hover cursor-pointer border-b border-border last:border-b-0">
<div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-accent-background-hovered cursor-pointer border-b border-border last:border-b-0">
<div
className="flex-grow"
onClick={() => {
@@ -1356,7 +1418,7 @@ export function AssistantEditor({
</>
)}
<div className="mt-12 gap-x-2 w-full justify-end flex">
<div className="mt-12 gap-x-2 w-full justify-end flex">
<Button
type="submit"
disabled={isSubmitting || isRequestSuccessful}

View File

@@ -31,7 +31,7 @@ export function HidableSection({
return (
<div>
<div
className="flex hover:bg-hover-light rounded cursor-pointer p-2"
className="flex hover:bg-accent-background rounded cursor-pointer p-2"
onClick={() => setIsHidden(!isHidden)}
>
<SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>

View File

@@ -11,13 +11,14 @@ import { DraggableTable } from "@/components/table/DraggableTable";
import {
deletePersona,
personaComparator,
togglePersonaDefault,
togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
import { TrashIcon } from "@/components/icons/icons";
import { useUser } from "@/components/user/UserProvider";
import { useAssistants } from "@/components/context/AssistantsContext";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
function PersonaTypeDisplay({ persona }: { persona: Persona }) {
if (persona.builtin_persona) {
@@ -56,6 +57,9 @@ export function PersonasTable() {
const [finalPersonas, setFinalPersonas] = useState<Persona[]>([]);
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
const [personaToDelete, setPersonaToDelete] = useState<Persona | null>(null);
const [defaultModalOpen, setDefaultModalOpen] = useState(false);
const [personaToToggleDefault, setPersonaToToggleDefault] =
useState<Persona | null>(null);
useEffect(() => {
const editable = editablePersonas.sort(personaComparator);
@@ -126,11 +130,39 @@ export function PersonasTable() {
}
};
const openDefaultModal = (persona: Persona) => {
setPersonaToToggleDefault(persona);
setDefaultModalOpen(true);
};
const closeDefaultModal = () => {
setDefaultModalOpen(false);
setPersonaToToggleDefault(null);
};
const handleToggleDefault = async () => {
if (personaToToggleDefault) {
const response = await togglePersonaDefault(
personaToToggleDefault.id,
personaToToggleDefault.is_default_persona
);
if (response.ok) {
await refreshAssistants();
closeDefaultModal();
} else {
setPopup({
type: "error",
message: `Failed to update persona - ${await response.text()}`,
});
}
}
};
return (
<div>
{popup}
{deleteModalOpen && personaToDelete && (
<DeleteEntityModal
<ConfirmEntityModal
entityType="Persona"
entityName={personaToDelete.name}
onClose={closeDeleteModal}
@@ -138,8 +170,35 @@ export function PersonasTable() {
/>
)}
{defaultModalOpen && personaToToggleDefault && (
<ConfirmEntityModal
variant="action"
entityType="Assistant"
entityName={personaToToggleDefault.name}
onClose={closeDefaultModal}
onSubmit={handleToggleDefault}
actionButtonText={
personaToToggleDefault.is_default_persona
? "Remove Featured"
: "Set as Featured"
}
additionalDetails={
personaToToggleDefault.is_default_persona
? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
: `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`
}
/>
)}
<DraggableTable
headers={["Name", "Description", "Type", "Is Visible", "Delete"]}
headers={[
"Name",
"Description",
"Type",
"Featured Assistant",
"Is Visible",
"Delete",
]}
isAdmin={isAdmin}
rows={finalPersonas.map((persona) => {
const isEditable = editablePersonas.includes(persona);
@@ -152,7 +211,9 @@ export function PersonasTable() {
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/admin/assistants/${persona.id}?u=${Date.now()}`
`/assistants/edit/${
persona.id
}?u=${Date.now()}&admin=true`
)
}
/>
@@ -168,6 +229,30 @@ export function PersonasTable() {
{persona.description}
</p>,
<PersonaTypeDisplay key={persona.id} persona={persona} />,
<div
key="is_default_persona"
onClick={() => {
if (isEditable) {
openDefaultModal(persona);
}
}}
className={`px-1 py-0.5 rounded flex ${
isEditable
? "hover:bg-accent-background-hovered cursor-pointer"
: ""
} select-none w-fit`}
>
<div className="my-auto flex-none w-22">
{!persona.is_default_persona ? (
<div className="text-error">Not Featured</div>
) : (
"Featured"
)}
</div>
<div className="ml-1 my-auto">
<CustomCheckbox checked={persona.is_default_persona} />
</div>
</div>,
<div
key="is_visible"
onClick={async () => {
@@ -187,7 +272,9 @@ export function PersonasTable() {
}
}}
className={`px-1 py-0.5 rounded flex ${
isEditable ? "hover:bg-hover cursor-pointer" : ""
isEditable
? "hover:bg-accent-background-hovered cursor-pointer"
: ""
} select-none w-fit`}
>
<div className="my-auto w-12">
@@ -205,7 +292,7 @@ export function PersonasTable() {
<div className="mr-auto my-auto">
{!persona.builtin_persona && isEditable ? (
<div
className="hover:bg-hover rounded p-1 cursor-pointer"
className="hover:bg-accent-background-hovered rounded p-1 cursor-pointer"
onClick={() => openDeleteModal(persona)}
>
<TrashIcon />

Some files were not shown because too many files have changed in this diff Show More