Compare commits

...

4 Commits

Author SHA1 Message Date
Evan Lohn
022e2f0a24 added utils file 2025-02-10 11:05:40 -08:00
Evan Lohn
7afb390256 fixed unit tests 2025-02-10 11:05:07 -08:00
Evan Lohn
8ebb08df09 anthropic tool calling fix 2025-02-10 10:19:55 -08:00
Evan Lohn
02148670e2 k 2025-02-06 19:57:14 -08:00
13 changed files with 153 additions and 74 deletions

View File

@@ -3,7 +3,6 @@ from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
@@ -13,6 +12,8 @@ from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.agent_configs import AGENT_MAX_TOOL_CALLS
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -22,7 +23,7 @@ def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
output=ToolChoiceUpdate,
)
### Add nodes ###
@@ -60,11 +61,15 @@ def basic_graph_builder() -> StateGraph:
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
graph.add_conditional_edges(
"basic_use_tool_response", should_continue, ["tool_call", END]
)
# graph.add_edge(
# start_key="basic_use_tool_response",
# end_key=END,
# )
return graph
@@ -72,7 +77,8 @@ def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
if state.tool_choices[-1] is None
or len(state.tool_choices) > AGENT_MAX_TOOL_CALLS
else "tool_call"
)

View File

@@ -30,11 +30,12 @@ def route_initial_tool_choice(
LangGraph edge to route to agent search.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if state.tool_choices[-1] is not None:
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
and state.tool_choices[-1].tool.name
== agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:

View File

@@ -4,10 +4,11 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.orchestration.utils import get_tool_choice_update
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
@@ -23,11 +24,15 @@ logger = setup_logger()
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
) -> ToolChoiceUpdate:
agent_config = cast(GraphConfig, config["metadata"]["config"])
structured_response_format = agent_config.inputs.structured_response_format
llm = agent_config.tooling.primary_llm
tool_choice = state.tool_choice
assert (
len(state.tool_choices) > 0
), "Tool choice node must have at least one tool choice"
tool_choice = state.tool_choices[-1]
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice.tool
@@ -61,6 +66,8 @@ def basic_use_tool_response(
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
tools=[_tool.tool_definition() for _tool in agent_config.tooling.tools],
tool_choice=None,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
@@ -74,4 +81,4 @@ def basic_use_tool_response(
displayed_search_results=initial_search_results or final_search_results,
)
return BasicOutput(tool_call_chunk=new_tool_call_chunk)
return get_tool_choice_update(new_tool_call_chunk, agent_config.tooling.tools)

View File

@@ -1,21 +1,21 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.orchestration.utils import get_tool_choice_update
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -26,7 +26,7 @@ logger = setup_logger()
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(
state: ToolChoiceState,
state: BasicState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
@@ -72,11 +72,13 @@ def llm_tool_choice(
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
),
tool_choices=[
ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
)
],
)
# if we're skipping gen ai answer generation, we should only
@@ -84,7 +86,7 @@ def llm_tool_choice(
# the tool calling llm in the stream() below)
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
return ToolChoiceUpdate(
tool_choice=None,
tool_choices=[None],
)
built_prompt = (
@@ -99,7 +101,9 @@ def llm_tool_choice(
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
tool_choice=(
ToolChoiceOptions.REQUIRED if tools and force_use_tool.force_use else None
),
structured_response_format=structured_response_format,
)
@@ -110,45 +114,4 @@ def llm_tool_choice(
writer,
)
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.debug("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
# TODO: here we could handle parallel tool calls. Right now
# we just pick the first one that matches.
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in tool_message.tool_calls:
known_tools_by_name = [
tool for tool in tools if tool.name == tool_call_request["name"]
]
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
break
logger.error(
"Tool call requested with unknown name field. \n"
f"tools: {tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
raise ValueError(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
),
)
return get_tool_choice_update(tool_message, tools)

View File

@@ -37,7 +37,10 @@ def tool_call(
cast(GraphConfig, config["metadata"]["config"])
tool_choice = state.tool_choice
assert (
len(state.tool_choices) > 0
), "Tool call node must have at least one tool choice"
tool_choice = state.tool_choices[-1]
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")

View File

@@ -1,3 +1,6 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
@@ -41,7 +44,7 @@ class ToolChoice(BaseModel):
class ToolChoiceUpdate(BaseModel):
tool_choice: ToolChoice | None = None
tool_choices: Annotated[list[ToolChoice | None], add] = []
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):

View File

@@ -0,0 +1,58 @@
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_tool_choice_update(
    tool_message: AIMessageChunk, tools: list[Tool]
) -> ToolChoiceUpdate:
    """Translate an LLM tool-call message into a ToolChoiceUpdate.

    Selects the first requested tool call whose name matches one of the
    known `tools`; an error is logged for each unknown tool name seen
    before a match. If the LLM emitted no tool calls at all, a `[None]`
    update is returned so downstream nodes know no tool was chosen.

    Raises:
        ValueError: if tool calls were emitted but none matched a known tool.
    """
    # No tool calls emitted by the LLM -> we should not choose a tool.
    if not tool_message.tool_calls:
        logger.debug("No tool calls emitted by LLM")
        return ToolChoiceUpdate(tool_choices=[None])

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    chosen_tool: Tool | None = None
    chosen_request: ToolCall | None = None
    for request in tool_message.tool_calls:
        # First-match semantics: keep the earliest tool in `tools` whose
        # name equals the requested name.
        candidate = next(
            (tool for tool in tools if tool.name == request["name"]), None
        )
        if candidate is not None:
            chosen_tool = candidate
            chosen_request = request
            break

        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {request}"
        )

    if not chosen_tool or not chosen_request:
        raise ValueError(
            f"Tool call attempted with tool {chosen_tool}, request {chosen_request}"
        )

    logger.debug(f"Selected tool: {chosen_tool.name}")
    logger.debug(f"Selected tool call request: {chosen_request}")

    return ToolChoiceUpdate(
        tool_choices=[
            ToolChoice(
                tool=chosen_tool,
                tool_args=chosen_request["args"],
                id=chosen_request["id"],
            )
        ],
    )

View File

@@ -19,6 +19,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import NO_TOOL_CALL_PREAMBLE
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -27,6 +28,7 @@ from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.utils import is_anthropic_tool_calling_model
def default_build_system_message(
@@ -138,6 +140,14 @@ class AnswerPromptBuilder:
self.system_message_and_token_cnt = None
return
if is_anthropic_tool_calling_model(
self.llm_config.model_provider, self.llm_config.model_name
):
if isinstance(system_message.content, str):
system_message.content += NO_TOOL_CALL_PREAMBLE
else:
system_message.content.append(NO_TOOL_CALL_PREAMBLE)
self.system_message_and_token_cnt = (
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),

View File

@@ -12,7 +12,7 @@ AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_DEFAULT_MAX_TOOL_CALLS = 3
#####
# Agent Configs
#####
@@ -77,4 +77,8 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
# Maximum number of tool calls the basic graph may loop through before it is
# forced to stream a final answer (checked in should_continue).
AGENT_MAX_TOOL_CALLS = int(
    os.environ.get("AGENT_MAX_TOOL_CALLS") or AGENT_DEFAULT_MAX_TOOL_CALLS
)  # 3
GRAPH_VERSION_NAME: str = "a"

View File

@@ -1,6 +1,6 @@
import abc
from collections.abc import Iterator
from typing import Literal
from enum import Enum
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import AIMessageChunk
@@ -15,7 +15,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"]
class ToolChoiceOptions(Enum):
    """Constraint on tool usage for an LLM call.

    Replaces the former ``Literal["required"] | Literal["auto"] | Literal["none"]``
    type alias; values mirror the string options accepted by the underlying
    LLM APIs.
    """

    REQUIRED = "required"  # the model must make a tool call
    AUTO = "auto"  # the model decides whether to call a tool
    NONE = "none"  # the model must not call a tool
class LLMConfig(BaseModel):

View File

@@ -41,6 +41,13 @@ CHAT_USER_CONTEXT_FREE_PROMPT = f"""
{{user_query}}
""".strip()
# Appended to the system prompt for Anthropic tool-calling models (see
# AnswerPromptBuilder): without it they tend to emit a textual preamble before
# the first tool call. We tried telling anthropic to not make repeated tool
# calls, but it didn't work very well. When anthropic models don't follow this
# convention, it leads to the user seeing "the model decided not to search"
# for a second, which isn't great UX.
NO_TOOL_CALL_PREAMBLE = (
    "\nThe first time you call a tool, call it IMMEDIATELY without a textual preamble."
)
# Design considerations for the below:
# - In case of uncertainty, favor yes search so place the "yes" sections near the start of the

View File

@@ -6,6 +6,8 @@ from onyx.configs.app_configs import AZURE_DALLE_API_KEY
from onyx.db.connector import check_connectors_exist
from onyx.db.document import check_docs_exist
from onyx.db.models import LLMProvider
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.tools.tool import Tool
@@ -18,9 +20,20 @@ OPEN_AI_TOOL_CALLING_MODELS = {
"gpt-4o-mini",
}
# Model-name prefix for the Anthropic models we enable explicit tool calling for.
ANTHROPIC_TOOL_CALLING_PREFIX = "claude-3-5-sonnet"


def is_anthropic_tool_calling_model(model_provider: str, model_name: str) -> bool:
    """Return True iff the provider is Anthropic and the model name starts
    with the supported tool-calling prefix."""
    if model_provider != ANTHROPIC_PROVIDER_NAME:
        return False
    return model_name.startswith(ANTHROPIC_TOOL_CALLING_PREFIX)
def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool:
return model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS
return (
model_provider == OPENAI_PROVIDER_NAME
and model_name in OPEN_AI_TOOL_CALLING_MODELS
) or is_anthropic_tool_calling_model(model_provider, model_name)
def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int:

View File

@@ -229,7 +229,7 @@ def test_answer_with_search_call(
)
# Second call should not include tools (as we're just generating the final answer)
assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"]
# assert "tools" not in second_call.kwargs or not second_call.kwargs["tools"]
# Second call should use the returned prompt from build_next_prompt
assert (
second_call.kwargs["prompt"]
@@ -237,7 +237,7 @@ def test_answer_with_search_call(
)
# Verify that tool_definition was called on the mock_search_tool
mock_search_tool.tool_definition.assert_called_once()
assert mock_search_tool.tool_definition.call_count == 2
else:
assert mock_llm.stream.call_count == 1
@@ -310,7 +310,7 @@ def test_answer_with_search_no_tool_calling(
call_args = mock_llm.stream.call_args
# Verify that no tools were passed to the LLM
assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"]
# assert "tools" not in call_args.kwargs or not call_args.kwargs["tools"]
# Verify that the prompt was built correctly
assert (