Compare commits

...

6 Commits

Author      SHA1        Message                         Date
Dane Urban  5848975679  Remove comment                  2026-01-08 19:21:24 -08:00
Dane Urban  dcc330010e  Remove comment                  2026-01-08 19:21:08 -08:00
Dane Urban  d0f5f1f5ae  Handle error and log            2026-01-08 19:20:28 -08:00
Dane Urban  3e475993ff  Change which event loop we get  2026-01-08 19:16:12 -08:00
Dane Urban  7c2b5fa822  Change loggin                   2026-01-08 17:29:00 -08:00
Dane Urban  409cfdc788  nits                            2026-01-08 17:23:08 -08:00


@@ -1,6 +1,8 @@
+import asyncio
 import datetime
 import json
 import os
+from collections.abc import AsyncGenerator
 from collections.abc import Generator
 from datetime import timedelta
 from uuid import UUID
@@ -103,6 +105,7 @@ from onyx.server.utils import PUBLIC_API_TAGS
 from onyx.utils.headers import get_custom_tool_additional_request_headers
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import mt_cloud_telemetry
+from onyx.utils.threadpool_concurrency import run_in_background
 from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()
@@ -507,7 +510,7 @@ def handle_new_chat_message(
 @router.post("/send-chat-message", response_model=None, tags=PUBLIC_API_TAGS)
-def handle_send_chat_message(
+async def handle_send_chat_message(
     chat_message_req: SendMessageRequest,
     request: Request,
     user: User | None = Depends(current_chat_accessible_user),
@@ -572,34 +575,63 @@ def handle_send_chat_message(
         # Note: LLM cost tracking is now handled in multi_llm.py
         return result

     # Streaming path, normal Onyx UI behavior
-    def stream_generator() -> Generator[str, None, None]:
+    # Use prod-cons pattern to continue processing even if request stops yielding
+    buffer: asyncio.Queue[str | None] = asyncio.Queue()
+    loop = asyncio.get_running_loop()
+
+    # Capture headers before spawning thread
+    litellm_headers = extract_headers(request.headers, LITELLM_PASS_THROUGH_HEADERS)
+    custom_tool_headers = get_custom_tool_additional_request_headers(request.headers)
+
+    def producer() -> None:
+        """
+        Producer function that runs handle_stream_message_objects in a loop
+        and writes results to the buffer.
+        """
         state_container = ChatStateContainer()
         try:
+            logger.debug("Producer started")
             with get_session_with_current_tenant() as db_session:
                 for obj in handle_stream_message_objects(
                     new_msg_req=chat_message_req,
                     user=user,
                     db_session=db_session,
-                    litellm_additional_headers=extract_headers(
-                        request.headers, LITELLM_PASS_THROUGH_HEADERS
-                    ),
-                    custom_tool_additional_headers=get_custom_tool_additional_request_headers(
-                        request.headers
-                    ),
+                    litellm_additional_headers=litellm_headers,
+                    custom_tool_additional_headers=custom_tool_headers,
                     external_state_container=state_container,
                 ):
-                    yield get_json_line(obj.model_dump())
+                    # Thread-safe put into the asyncio queue
+                    loop.call_soon_threadsafe(
+                        buffer.put_nowait, get_json_line(obj.model_dump())
+                    )

             # Note: LLM cost tracking is now handled in multi_llm.py
         except Exception as e:
             logger.exception("Error in chat message streaming")
-            yield json.dumps({"error": str(e)})
+            loop.call_soon_threadsafe(buffer.put_nowait, json.dumps({"error": str(e)}))
         finally:
-            logger.debug("Stream generator finished")
+            # Signal end of stream
+            loop.call_soon_threadsafe(buffer.put_nowait, None)
+            logger.debug("Producer finished")

-    return StreamingResponse(stream_generator(), media_type="text/event-stream")
+    async def stream_from_buffer() -> AsyncGenerator[str, None]:
+        """
+        Async generator that reads from the buffer and yields to the client.
+        """
+        try:
+            while True:
+                item = await buffer.get()
+                if item is None:
+                    # End of stream signal
+                    break
+                yield item
+        except asyncio.CancelledError:
+            logger.warning("Stream cancelled (Consumer disconnected)")
+        finally:
+            logger.debug("Stream consumer finished")
+
+    run_in_background(producer)
+    return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")
@router.put("/set-message-as-latest")
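
Note on the change above: the new streaming path runs the blocking chat pipeline in a background thread (the producer) and bridges its output back onto the request's event loop through an asyncio.Queue, so generation keeps going even if the client briefly stops reading. The sketch below is a minimal, self-contained illustration of that thread-to-event-loop bridge, not the Onyx implementation: produce_items is a hypothetical stand-in for handle_stream_message_objects, and a plain threading.Thread stands in for the run_in_background helper.

# Minimal sketch of the producer/consumer bridge used in the diff above.
# Assumption: produce_items and the "packet-N" payloads are illustrative only;
# threading.Thread approximates Onyx's run_in_background helper.
import asyncio
import threading
from collections.abc import AsyncGenerator
from collections.abc import Iterator


def produce_items() -> Iterator[str]:
    # Blocking generator standing in for the synchronous chat pipeline.
    for i in range(3):
        yield f"packet-{i}\n"


async def stream() -> AsyncGenerator[str, None]:
    # The queue lives on the event loop; None is the end-of-stream sentinel.
    buffer: asyncio.Queue[str | None] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def producer() -> None:
        # Runs in a worker thread: hand each result to the loop thread-safely.
        try:
            for item in produce_items():
                loop.call_soon_threadsafe(buffer.put_nowait, item)
        finally:
            # Always signal end of stream, even on error.
            loop.call_soon_threadsafe(buffer.put_nowait, None)

    threading.Thread(target=producer, daemon=True).start()

    # Consumer side: await items without blocking the event loop.
    while True:
        item = await buffer.get()
        if item is None:
            break
        yield item


async def main() -> None:
    async for line in stream():
        print(line, end="")


if __name__ == "__main__":
    asyncio.run(main())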