Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-04 14:32:41 +00:00)

Compare commits: cli/v0.2.1...richard/si (23 commits)
| SHA1 |
|---|
| 84726d3b22 |
| 89e770ebf2 |
| a370cc2ba9 |
| ae4fafecbc |
| 7cb01e57df |
| 74ed3c146d |
| d59c85a7a2 |
| 9d5a1b6405 |
| 41accfdc3f |
| e9baf77026 |
| 6f73659eee |
| 68020a763f |
| 3d14fe93de |
| 1d29302647 |
| 9b600fc2a5 |
| 37bcf63f4d |
| c178cb964a |
| 09381a4c48 |
| 5aab51edf4 |
| ad1bcfa063 |
| dbbbb1918f |
| b401bb0a9e |
| e042454d05 |
backend/onyx/chat/answer_cli.py (new file, 0 lines)

backend/onyx/chat/answer_scratchpad.py (new file, 720 lines)

@@ -0,0 +1,720 @@
from __future__ import annotations

import asyncio
import contextvars
import json
import os
import queue
import threading
from collections.abc import Generator
from collections.abc import Iterator
from dataclasses import dataclass
from queue import Queue
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional

import litellm
from agents import Agent
from agents import AgentHooks
from agents import function_tool
from agents import ModelSettings
from agents import RunContextWrapper
from agents import Runner
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.extensions.handoff_prompt import RECOMMENDED_PROMPT_PREFIX
from agents.extensions.models.litellm_model import LitellmModel
from agents.handoffs import HandoffInputData
from agents.stream_events import RawResponsesStreamEvent
from agents.stream_events import RunItemStreamEvent
from braintrust import traced
from openai.types import Reasoning
from pydantic import BaseModel

from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
from onyx.agents.agent_search.dr.dr_prompt_builder import (
    get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
    ExaClient,
)
from onyx.agents.agent_search.dr.utils import get_chat_history_string
from onyx.agents.agent_search.models import GraphConfig
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.interfaces import LLM
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger

logger = setup_logger()


@dataclass
class RunDependencies:
    emitter: Emitter
    llm: LLM
    search_tool: SearchTool | None = None


@dataclass
class MyContext:
    """Context class to hold search tool and other dependencies"""

    run_dependencies: RunDependencies | None = None
    needs_compaction: bool = False


def short_tag(link: str, i: int) -> str:
    # Stable, readable; index keeps it deterministic across a batch
    return f"S{i+1}"


@function_tool
def web_search(query: str) -> str:
    """Search the web for information. This tool provides urls and short snippets,
    but does not fetch the full content of the urls."""
    exa_client = ExaClient()
    hits = exa_client.search(query)
    results = []
    for i, r in enumerate(hits):
        results.append(
            {
                "tag": short_tag(r.link, i),  # <-- add a tag
                "title": r.title,
                "link": r.link,
                "snippet": r.snippet,
                "author": r.author,
                "published_date": (
                    r.published_date.isoformat() if r.published_date else None
                ),
            }
        )
    return json.dumps({"results": results})


@function_tool
def web_fetch(urls: List[str]) -> str:
    """Fetch the full contents of a list of URLs."""
    exa_client = ExaClient()
    docs = exa_client.contents(urls)
    out = []
    for i, d in enumerate(docs):
        out.append(
            {
                "tag": short_tag(d.link, i),  # <-- add a tag
                "title": d.title,
                "link": d.link,
                "full_content": d.full_content,
                "published_date": (
                    d.published_date.isoformat() if d.published_date else None
                ),
            }
        )
    return json.dumps({"results": out})


@traced(name="llm_completion", type="llm")
def llm_completion(
    model_name: str,
    temperature: float,
    messages: List[Dict[str, Any]],
    stream: bool = False,
) -> litellm.ModelResponse:
    return litellm.responses(
        model=model_name,
        input=messages,
        tools=[],
        stream=stream,
        # Reasoning comes from openai.types (imported above); litellm itself
        # does not export a Reasoning class.
        reasoning=Reasoning(effort="medium", summary="detailed"),
    )
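
# Usage sketch (illustrative, not executed at import time). Assumes
# OPENAI_API_KEY is configured and that litellm's Responses API result exposes
# OpenAI's `output_text` accessor — an assumption to verify, not a guarantee:
#
#   resp = llm_completion(
#       model_name="gpt-5-mini",
#       temperature=0.0,
#       messages=[{"role": "user", "content": "In one sentence, what is Onyx?"}],
#   )
#   print(resp.output_text)
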

@function_tool
def internal_search(context_wrapper: RunContextWrapper[MyContext], query: str) -> str:
    """Search internal company vector database for information. Sources
    include:
    - Fireflies (internal company call transcripts)
    - Google Drive (internal company documents)
    - Gmail (internal company emails)
    - Linear (internal company issues)
    - Slack (internal company messages)
    """
    context_wrapper.context.run_dependencies.emitter.emit(
        kind="tool-progress", data={"progress": "Searching internal database"}
    )
    search_tool = context_wrapper.context.run_dependencies.search_tool
    if search_tool is None:
        raise RuntimeError("Search tool not available in context")

    # Initialize up front so the return below is defined even if the tool
    # never yields a SEARCH_RESPONSE_SUMMARY_ID response.
    retrieved_docs: list[Any] = []
    with get_session_with_current_tenant() as search_db_session:
        for tool_response in search_tool.run(
            query=query,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=True,
                alternate_db_session=search_db_session,
                skip_query_analysis=True,
                original_query=query,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                break
    return retrieved_docs

def _convert_to_packet_obj(packet: Dict[str, Any]) -> Any | None:
    """Convert a packet dictionary to PacketObj when possible.

    Args:
        packet: Dictionary containing packet data

    Returns:
        PacketObj instance if conversion is possible, None otherwise
    """
    if not isinstance(packet, dict) or "type" not in packet:
        return None

    packet_type = packet.get("type")
    if not packet_type:
        return None

    try:
        # Import here to avoid circular imports
        from onyx.server.query_and_chat.streaming_models import (
            MessageStart,
            MessageDelta,
            OverallStop,
        )

        if packet_type == "response.output_item.added":
            return MessageStart(
                type="message_start",
                content="",
                final_documents=None,
            )
        elif packet_type == "response.output_text.delta":
            return MessageDelta(type="message_delta", content=packet["delta"])
        elif packet_type == "response.completed":
            return OverallStop(type="stop")

    except Exception as e:
        # Log the error but don't fail the entire process
        logger.debug(f"Failed to convert packet to PacketObj: {e}")

    return None
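
# Illustrative mapping (sketch, not part of the module's control flow): a raw
# Responses-API event dict converts to a streaming packet, unknown types to None.
#
#   _convert_to_packet_obj({"type": "response.output_text.delta", "delta": "Hi"})
#       -> MessageDelta(type="message_delta", content="Hi")
#   _convert_to_packet_obj({"type": "response.completed"})
#       -> OverallStop(type="stop")
#   _convert_to_packet_obj({"type": "unknown.event"})
#       -> None
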

# stream_bus.py
@dataclass
class StreamPacket:
    kind: str  # "agent" | "tool-progress" | "done"
    payload: Dict[str, Any] | None = None


class Emitter:
    """Use this inside tools to emit arbitrary UI progress."""

    def __init__(self, bus: Queue):
        self.bus = bus

    def emit(self, kind: str, data: Dict[str, Any]) -> None:
        self.bus.put(StreamPacket(kind=kind, payload=data))
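
# --- Usage sketch (illustrative, not called anywhere in this module). A
# producer thread pushes packets onto the bus; the consumer drains until the
# "done" sentinel — the same handshake unified_event_stream relies on below.
def _demo_emitter_roundtrip() -> None:
    bus: Queue = Queue()
    emitter = Emitter(bus)

    def producer() -> None:
        emitter.emit(kind="tool-progress", data={"progress": "halfway"})
        emitter.emit(kind="done", data={"ok": True})

    threading.Thread(target=producer, daemon=True).start()
    while True:
        pkt: StreamPacket = bus.get()
        print(pkt.kind, pkt.payload)
        if pkt.kind == "done":
            break
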

# If we want durable execution in the future, we can replace this with a temporal call
def start_run_in_thread(
    agent: Agent,
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    emitter: Emitter,
    search_tool: SearchTool | None = None,
) -> threading.Thread:
    def worker():
        async def amain():
            ctx = MyContext(
                run_dependencies=RunDependencies(
                    search_tool=search_tool,
                    emitter=emitter,
                    llm=llm,
                )
            )
            # 1) start the streamed run (async)
            streamed = Runner.run_streamed(agent, messages, context=ctx)

            # 2) forward the agent's async event stream
            async for ev in streamed.stream_events():
                if isinstance(ev, RunItemStreamEvent):
                    pass
                elif isinstance(ev, RawResponsesStreamEvent):
                    emitter.emit(kind="agent", data=ev.data.model_dump())

            emitter.emit(kind="done", data={"ok": True})

        # run the async main inside this thread
        asyncio.run(amain())

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    return t


class ResearchScratchpad(BaseModel):
    notes: List[dict] = []


scratchpad = ResearchScratchpad()


@function_tool
def add_note(note: str, source_url: str | None = None):
    """Store a factual note you want to cite later."""
    scratchpad.notes.append({"note": note, "source_url": source_url})
    return {"ok": True, "count": len(scratchpad.notes)}


@function_tool
def finalize_report():
    """Signal you're done researching. Return a structured, citation-rich report."""
    # The model should *compose* the report as the tool *result*, using notes in scratchpad.
    # Some teams have the model return the full report as this tool's return value
    # so the UI can detect completion cleanly.
    return {
        "status": "ready_to_render",
        "notes_index": scratchpad.notes,  # the model can read these to assemble citations
    }


class CompactionHooks(AgentHooks[Any]):
    async def on_llm_start(
        self,
        context: RunContextWrapper[MyContext],
        agent: Agent[Any],
        system_prompt: Optional[str],
        input_items: List[dict],
    ) -> None:
        logger.debug(f"[{agent.name}] LLM start")
        logger.debug(f"system_prompt: {system_prompt}")
        logger.debug(f"usage so far: {context.usage.total_tokens}")
        usage = context.usage.total_tokens
        if usage > 10000:
            context.context.needs_compaction = True


def compaction_input_filter(input_data: HandoffInputData):
    filtered_messages = []
    for msg in input_data.input_history[:-1]:
        if isinstance(msg, dict) and msg.get("content") is not None:
            # Convert tool messages to user messages to avoid API errors
            if msg.get("role") == "tool":
                filtered_msg = {
                    "role": "user",
                    "content": f"Tool response: {msg.get('content', '')}",
                }
                filtered_messages.append(filtered_msg)
            else:
                filtered_messages.append(msg)

    # Only proceed with compaction if we have valid messages; callers get None
    # otherwise. Note the SDK's handoff input_filter contract expects a
    # HandoffInputData back, not a bare list.
    if filtered_messages:
        return [filtered_messages[-1]]
    return None
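
# Wiring sketch (illustrative): compaction_input_filter is defined above but
# never attached to a handoff in this file. With the openai-agents SDK it
# would be attached roughly as below; `compactor_agent` is assumed to exist
# and is not defined here:
#
#   from agents import handoff
#
#   researcher = Agent(
#       name="Researcher",
#       ...,
#       handoffs=[
#           handoff(agent=compactor_agent, input_filter=compaction_input_filter)
#       ],
#   )
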

def construct_deep_research_agent(llm: LLM) -> Agent:
    litellm_model = LitellmModel(
        # If you have access, prefer OpenAI's deep research-capable models:
        # "o3-deep-research" or "o4-mini-deep-research"
        # otherwise keep your current model and lean on the prompt + tools
        model=llm.config.model_name,
        api_key=llm.config.api_key,
    )

    DR_INSTRUCTIONS = f"""
{RECOMMENDED_PROMPT_PREFIX}
You are a deep-research agent. Work in explicit iterations:
1) PLAN: Decompose the user's query into sub-questions and a step-by-step plan.
2) SEARCH: Use web_search to explore multiple angles, fanning out and searching in parallel.
3) FETCH: Use web_fetch for any promising URLs to extract specifics and quotes.
4) NOTE: After each useful find, call add_note(note, source_url) to save key facts.
5) REVISE: If evidence contradicts earlier assumptions, update your plan and continue.
6) FINALIZE: When confident, call finalize_report(). Your final answer must include:
   - Clear, structured conclusions
   - A short "How I searched" summary
   - Inline citations to sources (with URLs)
   - A bullet list of limitations/open questions
Guidelines:
- Prefer breadth-first exploration before deep dives.
- Compare sources and dates; prioritize recency for time-sensitive topics.
- Minimize redundancy by skimming before fetching.
- Think out loud in a compact way, but keep reasoning crisp.
- If context exceeds 10000 tokens, hand off to the compactor agent.
"""
    return Agent(
        name="Researcher",
        instructions=DR_INSTRUCTIONS,
        model=litellm_model,
        tools=[web_search, web_fetch, add_note, finalize_report, internal_search],
        model_settings=ModelSettings(
            temperature=llm.config.temperature,
            include_usage=True,
            parallel_tool_calls=True,
            # optional: let model choose tools freely
            # tool_choice="auto",  # if supported by your LitellmModel wrapper
        ),
        hooks=CompactionHooks(),
    )


def unified_event_stream(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    emitter: Emitter | None = None,
    search_tool: SearchTool | None = None,
) -> Generator[Dict[str, Any], None, None]:
    # Reuse the caller's emitter when one is provided; otherwise create a
    # fresh bus for this stream to drain.
    if emitter is None:
        emitter = Emitter(Queue())
    current_context = contextvars.copy_context()
    t = threading.Thread(
        target=current_context.run,
        args=(
            # thread_worker_dr_turn,
            thread_worker_simple_turn,
            messages,
            cfg,
            llm,
            emitter,
            search_tool,
        ),  # eval_context=None for now
        daemon=True,
    )
    t.start()
    done = False
    while not done:
        pkt: StreamPacket = emitter.bus.get()
        if pkt.kind == "done":
            done = True
        else:
            # Convert packet to PacketObj when possible
            packet_obj = _convert_to_packet_obj(pkt.payload)
            if packet_obj:
                # Convert PacketObj back to dict for compatibility
                yield packet_obj.model_dump()
            else:
                # Fallback to original payload
                yield pkt.payload


# This should be close to the API
def stream_chat_sync(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    search_tool: SearchTool | None = None,
) -> Generator[Dict[str, Any], None, None]:
    emitter = Emitter(Queue())
    return unified_event_stream(
        messages=messages,
        cfg=cfg,
        llm=llm,
        emitter=emitter,
        search_tool=search_tool,
    )
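
# --- Usage sketch (illustrative): draining the stream from synchronous code.
# The GraphConfig/LLM/SearchTool instances are assumed to come from the
# caller's session setup (as in stream_chat_message_objects); they are not
# constructed here.
def _demo_consume_stream(
    cfg: GraphConfig, llm: LLM, search_tool: SearchTool | None
) -> None:
    for packet in stream_chat_sync(
        messages=[{"role": "user", "content": "hello"}],
        cfg=cfg,
        llm=llm,
        search_tool=search_tool,
    ):
        # dicts such as {"type": "message_delta", "content": "..."}
        print(packet)
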

def construct_simple_agent(
    llm: LLM,
) -> Agent:
    litellm_model = LitellmModel(
        model="o3-mini",
        api_key=llm.config.api_key,
    )
    return Agent(
        name="Assistant",
        instructions="""
        You are a helpful assistant that can search the web, fetch content from URLs,
        and search internal databases. Please do some reasoning and then return your answer.
        """,
        model=litellm_model,
        tools=[web_search, web_fetch, internal_search],
        model_settings=ModelSettings(
            temperature=0.0,
            include_usage=True,  # Track usage metrics
            reasoning=Reasoning(
                effort="medium", summary="detailed", generate_summary="detailed"
            ),
            verbose=True,
        ),
    )


def thread_worker_dr_turn(messages, cfg, llm, emitter, search_tool):
    """
    Worker function for deep research turn that runs in a separate thread.

    Args:
        messages: List of messages for the conversation
        cfg: Graph configuration
        llm: Language model instance
        emitter: Event emitter for streaming responses
        search_tool: Search tool instance (optional)
    """
    try:
        dr_turn(messages, cfg, llm, emitter, search_tool)
    except Exception as e:
        logger.error(f"Error in dr_turn: {e}", exc_info=e, stack_info=True)
        emitter.emit(kind="done", data={"ok": False})


def thread_worker_simple_turn(messages, cfg, llm, emitter, search_tool):
    try:
        simple_turn(
            messages=messages,
            cfg=cfg,
            llm=llm,
            turn_event_stream_emitter=emitter,
            search_tool=search_tool,
        )
    except Exception as e:
        logger.error(f"Error in simple_turn: {e}", exc_info=e, stack_info=True)
        emitter.emit(kind="done", data={"ok": False})


SENTINEL = object()


class StreamBridge:
    """
    Spins up an asyncio loop in a background thread, starts Runner.run_streamed there,
    consumes its async event stream, and exposes a blocking .events() iterator.
    """

    def __init__(self, agent, messages, ctx, max_turns: int = 100):
        self.agent = agent
        self.messages = messages
        self.ctx = ctx
        self.max_turns = max_turns

        self._q: "queue.Queue[object]" = queue.Queue()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed = None

    def start(self):
        def worker():
            async def run_and_consume():
                # Create the streamed run *inside* the loop thread
                self._streamed = Runner.run_streamed(
                    self.agent,
                    self.messages,
                    context=self.ctx,
                    max_turns=self.max_turns,
                )
                try:
                    async for ev in self._streamed.stream_events():
                        self._q.put(ev)
                finally:
                    self._q.put(SENTINEL)

            # Each thread needs its own loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)
            try:
                self._loop.run_until_complete(run_and_consume())
            finally:
                self._loop.close()

        self._thread = threading.Thread(target=worker, daemon=True)
        self._thread.start()
        return self

    def events(self) -> Iterator[object]:
        while True:
            ev = self._q.get()
            if ev is SENTINEL:
                break
            yield ev

    def cancel(self):
        # Post a cancellation to the loop thread safely
        if self._loop and self._streamed:

            def _do_cancel():
                try:
                    self._streamed.cancel()
                except Exception:
                    pass

            self._loop.call_soon_threadsafe(_do_cancel)
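
# --- Usage sketch (illustrative): StreamBridge turns the agent's async event
# stream into a blocking iterator; simple_turn and dr_turn below consume it
# exactly this way.
def _demo_bridge(agent: Agent, messages: List[Dict[str, Any]], ctx: MyContext) -> None:
    bridge = StreamBridge(agent, messages, ctx, max_turns=10).start()
    try:
        for ev in bridge.events():
            print(type(ev).__name__)
    finally:
        bridge.cancel()
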

def simple_turn(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    turn_event_stream_emitter: Emitter,
    search_tool: SearchTool | None = None,
) -> None:
    simple_agent = construct_simple_agent(llm)
    ctx = MyContext(
        run_dependencies=RunDependencies(
            search_tool=search_tool, emitter=turn_event_stream_emitter, llm=llm
        )
    )
    bridge = StreamBridge(simple_agent, messages, ctx, max_turns=100).start()
    for ev in bridge.events():
        if isinstance(ev, RunItemStreamEvent):
            if ev.name == "reasoning_item_created":
                turn_event_stream_emitter.emit(
                    kind="reasoning", data=ev.item.raw_item.model_dump()
                )
        elif isinstance(ev, RawResponsesStreamEvent):
            turn_event_stream_emitter.emit(kind="agent", data=ev.data.model_dump())
    turn_event_stream_emitter.emit(kind="done", data={"ok": True})


def dr_turn(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    turn_event_stream_emitter: Emitter,  # TurnEventStream is the primary output of the turn
    search_tool: SearchTool | None = None,
) -> None:
    """
    Execute a deep research turn.

    Args:
        messages: List of messages for the conversation
        cfg: Graph configuration
        llm: Language model instance
        turn_event_stream_emitter: Event emitter for streaming responses
        search_tool: Search tool instance (optional)
    """
    clarification = get_clarification(
        messages, cfg, llm, turn_event_stream_emitter, search_tool
    )
    output = json.loads(clarification.choices[0].message.content)
    clarification_output = ClarificationOutput(**output)
    if clarification_output.clarification_needed:
        # Wrap the question in a dict to satisfy Emitter.emit's signature.
        turn_event_stream_emitter.emit(
            kind="agent",
            data={"clarification_question": clarification_output.clarification_question},
        )
        turn_event_stream_emitter.emit(kind="done", data={"ok": True})
        return
    dr_agent = construct_deep_research_agent(llm)
    ctx = MyContext(
        run_dependencies=RunDependencies(
            search_tool=search_tool,
            emitter=turn_event_stream_emitter,
            llm=llm,
        )
    )
    bridge = StreamBridge(dr_agent, messages, ctx, max_turns=100).start()
    for ev in bridge.events():
        if isinstance(ev, RunItemStreamEvent):
            pass
        elif isinstance(ev, RawResponsesStreamEvent):
            turn_event_stream_emitter.emit(kind="agent", data=ev.data.model_dump())

    turn_event_stream_emitter.emit(kind="done", data={"ok": True})


class ClarificationOutput(BaseModel):
    clarification_question: str
    clarification_needed: bool
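
# Parsing sketch (illustrative): dr_turn above expects the clarifier LLM to
# return JSON matching this model, e.g.:
#
#   raw = '{"clarification_question": "Which quarter?", "clarification_needed": true}'
#   parsed = ClarificationOutput(**json.loads(raw))
#   parsed.clarification_needed  # -> True
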

def get_clarification(
    messages: List[Dict[str, Any]],
    cfg: GraphConfig,
    llm: LLM,
    emitter: Emitter,
    search_tool: SearchTool | None = None,
) -> litellm.ModelResponse:
    chat_history_string = (
        get_chat_history_string(
            cfg.inputs.prompt_builder.message_history,
            MAX_CHAT_HISTORY_MESSAGES,
        )
        or "(No chat history yet available)"
    )
    base_clarification_prompt = get_dr_prompt_orchestration_templates(
        DRPromptPurpose.CLARIFICATION,
        research_type=ResearchType.DEEP,
        entity_types_string=None,
        relationship_types_string=None,
        available_tools={},
    )
    clarification_prompt = base_clarification_prompt.build(
        question=messages[-1]["content"],
        chat_history_string=chat_history_string,
    )
    clarifier_prompt = prompt_with_handoff_instructions(clarification_prompt)
    llm_response = llm_completion(
        model_name=llm.config.model_name,
        temperature=llm.config.temperature,
        messages=[{"role": "user", "content": clarifier_prompt}],
        stream=False,
    )
    return llm_response


if __name__ == "__main__":
    messages = [
        {
            "role": "user",
            "content": """
Let $N$ denote the number of ordered triples of positive integers $(a, b, c)$ such that $a, b, c
\\leq 3^6$ and $a^3 + b^3 + c^3$ is a multiple of $3^7$. Find the remainder when $N$ is divided by $1000$.
""",
        }
    ]
    # OpenAI reasoning is not supported yet due to: https://github.com/BerriAI/litellm/pull/14117
    reasoning_agent = Agent(
        name="Reasoning",
        instructions="You are a reasoning agent. You are given a question and you need to reason about it.",
        model=LitellmModel(
            model="gpt-5-mini",
            api_key=os.getenv("OPENAI_API_KEY"),
        ),
        tools=[],
        model_settings=ModelSettings(
            temperature=0.0,
            reasoning=Reasoning(effort="medium", summary="detailed"),
        ),
    )
    llm_response = llm_completion(
        model_name="gpt-5-mini",
        temperature=0.0,
        messages=messages,
        stream=False,
    )
    x = llm_response.json()
    print(x)

@@ -3,13 +3,16 @@ import time
 import traceback
 from collections.abc import Callable
+from collections.abc import Iterator
 from typing import Any
 from typing import cast
+from typing import Dict
 from typing import Protocol

 from sqlalchemy.orm import Session

 from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
 from onyx.chat.answer import Answer
+from onyx.chat.answer_scratchpad import stream_chat_sync
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
 from onyx.chat.chat_utils import process_kg_commands
@@ -24,9 +27,6 @@ from onyx.chat.models import PromptConfig
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import StreamingError
 from onyx.chat.models import UserKnowledgeFilePacket
-from onyx.chat.packet_proccessing.process_streamed_packets import (
-    process_streamed_packets,
-)
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
 from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
@@ -41,7 +41,6 @@ from onyx.configs.constants import NO_AUTH_USER_ID
 from onyx.context.search.enums import OptionalSearchSetting
 from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import RetrievalDetails
 from onyx.context.search.models import SavedSearchDoc
 from onyx.context.search.retrieval.search_runner import (
     inference_sections_from_ids,
 )
@@ -76,11 +75,7 @@ from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import litellm_exception_to_error_msg
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.server.query_and_chat.models import CreateChatMessageRequest
-from onyx.server.query_and_chat.streaming_models import CitationDelta
 from onyx.server.query_and_chat.streaming_models import CitationInfo
-from onyx.server.query_and_chat.streaming_models import MessageDelta
-from onyx.server.query_and_chat.streaming_models import MessageStart
-from onyx.server.query_and_chat.streaming_models import Packet
 from onyx.server.utils import get_json_line
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import SearchToolOverrideKwargs
@@ -671,10 +666,49 @@ def stream_chat_message_objects(
             use_agentic_search=new_msg_req.use_agentic_search,
             skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
         )
-        # Process streamed packets using the new packet processing module
-        yield from process_streamed_packets(
-            answer_processed_output=answer.processed_streamed_output,
-        )
+        type_to_role = {
+            "human": "user",
+            "assistant": "assistant",
+            "system": "system",
+            "function": "function",
+        }
+        SYSTEM_PROMPT = """
+You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \
+user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \
+provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always \
+prioritize being truthful, nuanced, insightful, and efficient.
+The current date is September 18, 2025.
+
+You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \
+your responses more readable and engaging.
+You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \
+symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
+For code you prefer to use Markdown and specify the language.
+You can use Markdown horizontal rules (---) to separate sections of your responses.
+You can use Markdown tables to format your responses for data, lists, and other structured information.
+
+You must cite inline using tags from tool results.
+
+Rules:
+- Only cite sources provided by the tools (use each item's "tag" field).
+- Place the citation immediately after the claim it supports, like this: "... result [S1](https://linkforS1)" or
+  "... results [S1](https://linkforS1)[S3](https://linkforS3)".
+- If multiple sentences in a row are supported by the same source, cite the first sentence;
+  then omit repeats until the source changes.
+- Never invent tags. If no source supports a claim, say so.
+- Do not add a separate "Sources" section unless asked.
+"""
+        system_message = [{"role": "system", "content": SYSTEM_PROMPT}]
+        other_messages = [
+            {"role": type_to_role[message.type], "content": message.content}
+            for message in answer.graph_inputs.prompt_builder.build()
+            if message.type != "system"
+        ]
+        yield from stream_chat_sync(
+            messages=system_message + other_messages,
+            cfg=answer.graph_config,
+            llm=answer.graph_tooling.primary_llm,
+            search_tool=answer.graph_tooling.search_tool,
+        )

     except ValueError as e:
@@ -736,7 +770,15 @@ def stream_chat_message(
             document_retrieval_latency = time.time() - start_time
             logger.debug(f"First doc time: {document_retrieval_latency}")

-        yield get_json_line(obj.model_dump())
+        # Convert Pydantic models to dictionaries for JSON serialization
+        if hasattr(obj, "model_dump"):
+            obj_dict = obj.model_dump()
+        elif hasattr(obj, "dict"):
+            obj_dict = obj.dict()
+        else:
+            obj_dict = obj
+
+        yield get_json_line(obj_dict)
@@ -745,48 +787,98 @@ def remove_answer_citations(answer: str) -> str:
     return re.sub(pattern, "", answer)


+def _convert_to_packet_obj(packet: Dict[str, Any]) -> Any | None:
+    """Convert a packet dictionary to PacketObj when possible.
+
+    Args:
+        packet: Dictionary containing packet data
+
+    Returns:
+        PacketObj instance if conversion is possible, None otherwise
+    """
+    if not isinstance(packet, dict) or "type" not in packet:
+        return None
+
+    packet_type = packet.get("type")
+    if not packet_type:
+        return None
+
+    try:
+        # Import here to avoid circular imports
+        from onyx.server.query_and_chat.streaming_models import (
+            MessageStart,
+            MessageDelta,
+            OverallStop,
+            SectionEnd,
+            SearchToolStart,
+            SearchToolDelta,
+            ImageGenerationToolStart,
+            ImageGenerationToolDelta,
+            ImageGenerationToolHeartbeat,
+            CustomToolStart,
+            CustomToolDelta,
+            ReasoningStart,
+            ReasoningDelta,
+            CitationStart,
+            CitationDelta,
+        )
+
+        # Map packet types to their corresponding classes
+        type_mapping = {
+            "message_start": MessageStart,
+            "message_delta": MessageDelta,
+            "stop": OverallStop,
+            "section_end": SectionEnd,
+            "internal_search_tool_start": SearchToolStart,
+            "internal_search_tool_delta": SearchToolDelta,
+            "image_generation_tool_start": ImageGenerationToolStart,
+            "image_generation_tool_delta": ImageGenerationToolDelta,
+            "image_generation_tool_heartbeat": ImageGenerationToolHeartbeat,
+            "custom_tool_start": CustomToolStart,
+            "custom_tool_delta": CustomToolDelta,
+            "reasoning_start": ReasoningStart,
+            "reasoning_delta": ReasoningDelta,
+            "citation_start": CitationStart,
+            "citation_delta": CitationDelta,
+        }
+
+        packet_class = type_mapping.get(packet_type)
+        if packet_class:
+            # Create instance using the packet data, filtering out None values
+            filtered_data = {k: v for k, v in packet.items() if v is not None}
+            return packet_class(**filtered_data)
+
+    except Exception as e:
+        # Log the error but don't fail the entire process
+        logger.debug(f"Failed to convert packet to PacketObj: {e}")
+
+    return None
+
+
 @log_function_time()
 def gather_stream(
-    packets: AnswerStream,
+    packets: Iterator[Dict[str, Any]],
 ) -> ChatBasicResponse:
     answer = ""
     citations: list[CitationInfo] = []
     error_msg: str | None = None
     message_id: int | None = None
     top_documents: list[SavedSearchDoc] = []

     for packet in packets:
-        if isinstance(packet, Packet):
-            # Handle the different packet object types
-            if isinstance(packet.obj, MessageStart):
-                # MessageStart contains the initial content and final documents
-                if packet.obj.content:
-                    answer += packet.obj.content
-                if packet.obj.final_documents:
-                    top_documents = packet.obj.final_documents
-            elif isinstance(packet.obj, MessageDelta):
-                # MessageDelta contains incremental content updates
-                if packet.obj.content:
-                    answer += packet.obj.content
-            elif isinstance(packet.obj, CitationDelta):
-                # CitationDelta contains citation information
-                if packet.obj.citations:
-                    citations.extend(packet.obj.citations)
-        elif isinstance(packet, StreamingError):
-            error_msg = packet.error
-        elif isinstance(packet, MessageResponseIDInfo):
-            message_id = packet.reserved_assistant_message_id
+        if packet != {"type": "event"}:
+            print(packet)

-    if message_id is None:
-        raise ValueError("Message ID is required")
+        # Convert packet to PacketObj when possible
+        packet_obj = _convert_to_packet_obj(packet)
+        if packet_obj:
+            # Handle PacketObj types that contain text content
+            if hasattr(packet_obj, "content") and packet_obj.content:
+                answer += packet_obj.content
+        elif "text" in packet:
+            # Fallback for legacy packet format
+            answer += packet["text"]

     return ChatBasicResponse(
         answer=answer,
         answer_citationless=remove_answer_citations(answer),
-        cited_documents={
-            citation.citation_num: citation.document_id for citation in citations
-        },
-        message_id=message_id,
-        error_msg=error_msg,
-        top_documents=top_documents,
+        cited_documents={},
+        message_id=0,
+        error_msg=None,
+        top_documents=[],
     )

backend/onyx/evals/demo_agent.py (new file, 79 lines)

@@ -0,0 +1,79 @@
import asyncio
import os

from agents import ModelSettings
from agents import run_demo_loop
from agents.agent import Agent
from agents.extensions.handoff_prompt import prompt_with_handoff_instructions
from agents.extensions.models.litellm_model import LitellmModel
from pydantic import BaseModel

from onyx.agents.agent_search.dr.dr_prompt_builder import (
    get_dr_prompt_orchestration_templates,
)
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import DRPromptPurpose


def construct_simple_agent() -> Agent:
    litellm_model = LitellmModel(
        model="gpt-4.1",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    return Agent(
        name="Assistant",
        instructions="""
        You are a helpful assistant that can search the web, fetch content from URLs,
        and search internal databases.
        """,
        model=litellm_model,
        tools=[],
        model_settings=ModelSettings(
            temperature=0.0,
            include_usage=True,  # Track usage metrics
        ),
    )


class ClarificationOutput(BaseModel):
    clarification_question: str
    clarification_needed: bool


def construct_dr_agent() -> Agent:
    simple_agent = construct_simple_agent()
    litellm_model = LitellmModel(
        model="gpt-4.1",
        api_key=os.getenv("OPENAI_API_KEY"),
    )
    base_clarification_prompt = get_dr_prompt_orchestration_templates(
        DRPromptPurpose.CLARIFICATION,
        research_type=ResearchType.DEEP,
        entity_types_string=None,
        relationship_types_string=None,
        available_tools={},
    )
    clarification_prompt = base_clarification_prompt.build(
        question="",
        chat_history_string="",
    )
    clarifier_prompt = prompt_with_handoff_instructions(clarification_prompt)
    clarifier_agent = Agent(
        name="Clarifier",
        instructions=clarifier_prompt,
        model=litellm_model,
        tools=[],
        output_type=ClarificationOutput,
        handoffs=[simple_agent],
        model_settings=ModelSettings(tool_choice="required"),
    )
    return clarifier_agent


async def main() -> None:
    agent = construct_dr_agent()
    await run_demo_loop(agent)


if __name__ == "__main__":
    asyncio.run(main())

@@ -100,6 +100,7 @@ def run_eval(
     data: list[dict[str, dict[str, str]]] | None = None,
     remote_dataset_name: str | None = None,
     provider: EvalProvider = get_default_provider(),
+    no_send_logs: bool = False,
 ) -> EvalationAck:
     if data is not None and remote_dataset_name is not None:
         raise ValueError("Cannot specify both data and remote_dataset_name")
@@ -112,4 +113,5 @@ def run_eval(
         configuration=configuration,
         data=data,
         remote_dataset_name=remote_dataset_name,
+        no_send_logs=no_send_logs,
     )

@@ -42,6 +42,7 @@ def run_local(
     local_data_path: str | None,
     remote_dataset_name: str | None,
     search_permissions_email: str | None = None,
+    no_send_logs: bool = False,
 ) -> EvalationAck:
     """
     Run evaluation with local configurations.
@@ -67,7 +68,9 @@ def run_local(

     if remote_dataset_name:
         score = run_eval(
-            configuration=configuration, remote_dataset_name=remote_dataset_name
+            configuration=configuration,
+            remote_dataset_name=remote_dataset_name,
+            no_send_logs=no_send_logs,
         )
     else:
         if local_data_path is None:
@@ -75,7 +78,9 @@ def run_local(
                 "local_data_path or remote_dataset_name is required for local evaluation"
             )
         data = load_data_local(local_data_path)
-        score = run_eval(configuration=configuration, data=data)
+        score = run_eval(
+            configuration=configuration, data=data, no_send_logs=no_send_logs
+        )

     return score

@@ -172,6 +177,13 @@ def main() -> None:
         help="Email address to impersonate for the evaluation",
     )

+    parser.add_argument(
+        "--no-send-logs",
+        action="store_true",
+        help="Do not send logs to the remote server",
+        default=False,
+    )
+
     args = parser.parse_args()

     if args.local_data_path:
@@ -215,6 +227,7 @@ def main() -> None:
         local_data_path=args.local_data_path,
         remote_dataset_name=args.remote_dataset_name,
         search_permissions_email=args.search_permissions_email,
+        no_send_logs=args.no_send_logs,
     )

@@ -78,5 +78,6 @@ class EvalProvider(ABC):
         configuration: EvalConfigurationOptions,
         data: list[dict[str, dict[str, str]]] | None = None,
         remote_dataset_name: str | None = None,
+        no_send_logs: bool = False,
     ) -> EvalationAck:
         pass

@@ -109,8 +109,7 @@ def parse_csv_file(csv_path: str) -> List[Dict[str, Any]]:
     records.extend(
         [
             {
-                "question": question
-                + ". All info is contained in the quesiton. DO NOT ask any clarifying questions.",
+                "question": question,
                 "research_type": "DEEP",
                 "categories": categories,
                 "expected_depth": expected_depth,

@@ -1,5 +1,6 @@
 from collections.abc import Callable

+from autoevals.llm import LLMClassifier
 from braintrust import Eval
 from braintrust import EvalCase
 from braintrust import init_dataset
@@ -11,6 +12,33 @@ from onyx.evals.models import EvalConfigurationOptions
 from onyx.evals.models import EvalProvider


+quality_classifier = LLMClassifier(
+    name="quality",
+    prompt_template="""
+You are a customer doing a trial of the product Onyx. Onyx provides a UI for users to chat with an LLM
+and search for information, similar to ChatGPT. You think ChatGPT's answer quality is great, and
+you want to rate Onyx's response relative to ChatGPT's response.\n
+[Question]: {{input}}\n
+[ChatGPT Answer]: {{expected}}\n
+[Onyx Answer]: {{output}}\n
+
+Please rate the quality of the Onyx answer relative to the ChatGPT answer on a scale of A to E:
+A: The Onyx answer is great and is as good or better than the ChatGPT answer.
+B: The Onyx answer is good and comparable to the ChatGPT answer.
+C: The Onyx answer is fair.
+D: The Onyx answer is poor and is worse than the ChatGPT answer.
+E: The Onyx answer is terrible and is much worse than the ChatGPT answer.
+""",
+    choice_scores={
+        "A": 1,
+        "B": 0.75,
+        "C": 0.5,
+        "D": 0.25,
+        "E": 0,
+    },
+)
+
+
 class BraintrustEvalProvider(EvalProvider):
     def eval(
         self,
@@ -18,6 +46,7 @@ class BraintrustEvalProvider(EvalProvider):
         configuration: EvalConfigurationOptions,
         data: list[dict[str, dict[str, str]]] | None = None,
         remote_dataset_name: str | None = None,
+        no_send_logs: bool = False,
     ) -> EvalationAck:
         if data is not None and remote_dataset_name is not None:
             raise ValueError("Cannot specify both data and remote_dataset_name")
@@ -35,6 +64,7 @@ class BraintrustEvalProvider(EvalProvider):
                 scores=[],
                 metadata={**configuration.model_dump()},
                 max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
+                no_send_logs=no_send_logs,
             )
         else:
             if data is None:
@@ -51,5 +81,6 @@ class BraintrustEvalProvider(EvalProvider):
                 scores=[],
                 metadata={**configuration.model_dump()},
                 max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
+                no_send_logs=no_send_logs,
             )
         return EvalationAck(success=True)
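
For orientation, a sketch of invoking the classifier above directly as a scorer; the keyword call convention (`input`/`output`/`expected`) and the `Score.score` field are assumptions based on autoevals' usual scorer interface, and the import path is hypothetical:

```python
# Hypothetical import path; the diff does not show this module's final location.
from onyx.evals.providers.braintrust import quality_classifier

# Assumed autoevals scorer call convention (verify against your autoevals version).
result = quality_classifier(
    input="What is Onyx?",
    output="Onyx is an open-source AI chat and enterprise search platform.",
    expected="Onyx is an open-source enterprise search and chat tool.",
)
print(result.score)  # 1.0 for grade "A", down to 0 for grade "E"
```
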
@@ -1,13 +1,17 @@
+import os
 from typing import Any

 import braintrust
+from agents import set_trace_processors
+from braintrust import init_logger
+from braintrust.wrappers.openai import BraintrustTracingProcessor
 from braintrust_langchain import set_global_handler
 from braintrust_langchain.callbacks import BraintrustCallbackHandler

 from onyx.configs.app_configs import BRAINTRUST_API_KEY
 from onyx.configs.app_configs import BRAINTRUST_PROJECT

-MASKING_LENGTH = 20000
+MASKING_LENGTH = int(os.environ.get("BRAINTRUST_MASKING_LENGTH", "20000"))


 def _truncate_str(s: str) -> str:
@@ -33,3 +37,4 @@ def setup_braintrust() -> None:
     braintrust.set_masking_function(_mask)
     handler = BraintrustCallbackHandler()
     set_global_handler(handler)
+    set_trace_processors([BraintrustTracingProcessor(init_logger(BRAINTRUST_PROJECT))])

@@ -4,7 +4,6 @@ from typing import Literal

 from braintrust import traced
 from langchain.schema.language_model import LanguageModelInput
-from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import BaseMessage
 from pydantic import BaseModel

@@ -34,29 +33,30 @@ class LLMConfig(BaseModel):


 def log_prompt(prompt: LanguageModelInput) -> None:
-    if isinstance(prompt, list):
-        for ind, msg in enumerate(prompt):
-            if isinstance(msg, AIMessageChunk):
-                if msg.content:
-                    log_msg = msg.content
-                elif msg.tool_call_chunks:
-                    log_msg = "Tool Calls: " + str(
-                        [
-                            {
-                                key: value
-                                for key, value in tool_call.items()
-                                if key != "index"
-                            }
-                            for tool_call in msg.tool_call_chunks
-                        ]
-                    )
-                else:
-                    log_msg = ""
-                logger.debug(f"Message {ind}:\n{log_msg}")
-            else:
-                logger.debug(f"Message {ind}:\n{msg.content}")
-    if isinstance(prompt, str):
-        logger.debug(f"Prompt:\n{prompt}")
+    # if isinstance(prompt, list):
+    #     for ind, msg in enumerate(prompt):
+    #         if isinstance(msg, AIMessageChunk):
+    #             if msg.content:
+    #                 log_msg = msg.content
+    #             elif msg.tool_call_chunks:
+    #                 log_msg = "Tool Calls: " + str(
+    #                     [
+    #                         {
+    #                             key: value
+    #                             for key, value in tool_call.items()
+    #                             if key != "index"
+    #                         }
+    #                         for tool_call in msg.tool_call_chunks
+    #                     ]
+    #                 )
+    #             else:
+    #                 log_msg = ""
+    #             logger.debug(f"Message {ind}:\n{log_msg}")
+    #         else:
+    #             logger.debug(f"Message {ind}:\n{msg.content}")
+    # if isinstance(prompt, str):
+    #     logger.debug(f"Prompt:\n{prompt}")
+    pass


 class LLM(abc.ABC):

@@ -1160,7 +1160,7 @@ ANSWER:


 GET_CLARIFICATION_PROMPT = PromptTemplate(
-    f"""\
+    """\
 You are great at asking clarifying questions in case \
 a base question is not clear enough. Your task is to ask necessary clarification \
 questions to the user, before the question is sent to the deep research agent.
@@ -1183,17 +1183,6 @@ In case the knowledge graph is used, here is the description of the entity and r
 The tools and the entity and relationship types in the knowledge graph are simply provided \
 as context for determining whether the question requires clarification.

-Here is the question the user asked:
-{SEPARATOR_LINE}
----question---
-{SEPARATOR_LINE}
-
-Here is the previous chat history (if any), which may contain relevant information \
-to answer the question:
-{SEPARATOR_LINE}
----chat_history_string---
-{SEPARATOR_LINE}
-
 NOTES:
 - you have to reason over this purely based on your intrinsic knowledge.
 - if clarifications are required, fill in 'true' for the "feedback_needed" field and \

@@ -21,7 +21,6 @@ from sqlalchemy.orm import Session

 from onyx.auth.users import current_chat_accessible_user
 from onyx.auth.users import current_user
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import extract_headers
 from onyx.chat.process_message import stream_chat_message
 from onyx.chat.prompt_builder.citations_prompt import (
@@ -63,13 +62,8 @@ from onyx.db.user_documents import create_user_files
 from onyx.file_processing.extract_file_text import docx_to_txt_filename
 from onyx.file_store.file_store import get_default_file_store
 from onyx.file_store.models import FileDescriptor
-from onyx.llm.exceptions import GenAIDisabledException
-from onyx.llm.factory import get_default_llms
 from onyx.llm.factory import get_llms_for_persona
 from onyx.natural_language_processing.utils import get_tokenizer
-from onyx.secondary_llm_flows.chat_session_naming import (
-    get_renamed_conversation_name,
-)
 from onyx.server.documents.models import ConnectorBase
 from onyx.server.documents.models import CredentialBase
 from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
@@ -305,45 +299,44 @@ def rename_chat_session(
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> RenameChatSessionResponse:
-    name = rename_req.name
-    chat_session_id = rename_req.chat_session_id
-    user_id = user.id if user is not None else None
+    # name = rename_req.name
+    # chat_session_id = rename_req.chat_session_id
+    # user_id = user.id if user is not None else None

-    if name:
-        update_chat_session(
-            db_session=db_session,
-            user_id=user_id,
-            chat_session_id=chat_session_id,
-            description=name,
-        )
-        return RenameChatSessionResponse(new_name=name)
+    # if name:
+    #     update_chat_session(
+    #         db_session=db_session,
+    #         user_id=user_id,
+    #         chat_session_id=chat_session_id,
+    #         description=name,
+    #     )
+    #     return RenameChatSessionResponse(new_name=name)

-    final_msg, history_msgs = create_chat_chain(
-        chat_session_id=chat_session_id, db_session=db_session
-    )
-    full_history = history_msgs + [final_msg]
+    # final_msg, history_msgs = create_chat_chain(
+    #     chat_session_id=chat_session_id, db_session=db_session
+    # )
+    # full_history = history_msgs + [final_msg]

-    try:
-        llm, _ = get_default_llms(
-            additional_headers=extract_headers(
-                request.headers, LITELLM_PASS_THROUGH_HEADERS
-            )
-        )
-    except GenAIDisabledException:
-        # This may be longer than what the LLM tends to produce but is the most
-        # clear thing we can do
-        return RenameChatSessionResponse(new_name=full_history[0].message)
+    # try:
+    #     llm, _ = get_default_llms(
+    #         additional_headers=extract_headers(
+    #             request.headers, LITELLM_PASS_THROUGH_HEADERS
+    #         )
+    #     )
+    # except GenAIDisabledException:
+    #     # This may be longer than what the LLM tends to produce but is the most
+    #     # clear thing we can do
+    #     return RenameChatSessionResponse(new_name=full_history[0].message)

-    new_name = get_renamed_conversation_name(full_history=full_history, llm=llm)
+    # new_name = get_renamed_conversation_name(full_history=full_history, llm=llm)

-    update_chat_session(
-        db_session=db_session,
-        user_id=user_id,
-        chat_session_id=chat_session_id,
-        description=new_name,
-    )
-
-    return RenameChatSessionResponse(new_name=new_name)
+    # update_chat_session(
+    #     db_session=db_session,
+    #     user_id=user_id,
+    #     chat_session_id=chat_session_id,
+    #     description=new_name,
+    # )
+    return RenameChatSessionResponse(new_name="hi")


 @router.patch("/chat-session/{session_id}")