mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-04 14:32:41 +00:00
Compare commits
11 Commits
cli/v0.2.1
...
eric/agent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c112e29822 | ||
|
|
8210a5c46a | ||
|
|
1d3258d762 | ||
|
|
ea895ddac2 | ||
|
|
0f5f76563b | ||
|
|
e8e187beb9 | ||
|
|
c3c28b956c | ||
|
|
01d2d8e138 | ||
|
|
7af8e4e874 | ||
|
|
a9f5855ff4 | ||
|
|
6723bc7632 |
72
backend/alembic/versions/6c5e6b7fbbab_agent_to_agent.py
Normal file
72
backend/alembic/versions/6c5e6b7fbbab_agent_to_agent.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""agent to agent
|
||||
|
||||
Revision ID: 6c5e6b7fbbab
|
||||
Revises: e8f0d2a38171
|
||||
Create Date: 2025-12-05 10:56:43.190279
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6c5e6b7fbbab"
|
||||
down_revision = "e8f0d2a38171"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"persona__persona",
|
||||
sa.Column("parent_persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("child_persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"pass_conversation_context",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"pass_files",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("max_tokens_to_child", sa.Integer(), nullable=True),
|
||||
sa.Column("max_tokens_from_child", sa.Integer(), nullable=True),
|
||||
sa.Column("invocation_instructions", sa.Text(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["parent_persona_id"],
|
||||
["persona.id"],
|
||||
ondelete="CASCADE",
|
||||
name="fk_persona__persona_parent_persona_id_persona",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["child_persona_id"],
|
||||
["persona.id"],
|
||||
ondelete="CASCADE",
|
||||
name="fk_persona__persona_child_persona_id_persona",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("parent_persona_id", "child_persona_id"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("invoked_persona_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_tool_call_invoked_persona",
|
||||
"tool_call",
|
||||
"persona",
|
||||
["invoked_persona_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_tool_call_invoked_persona", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "invoked_persona_id")
|
||||
op.drop_table("persona__persona")
|
||||
@@ -1015,6 +1015,23 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
# Check if this is an agent tool to get the invoked persona id and collected data
|
||||
invoked_persona_id = None
|
||||
agent_search_docs = search_docs
|
||||
if hasattr(tool, "_child_persona"):
|
||||
invoked_persona_id = tool._child_persona.id
|
||||
# Include collected search queries and docs from child agent's tools
|
||||
tool_args = dict(tool_call.tool_args)
|
||||
if (
|
||||
hasattr(tool, "_nested_agent_runs")
|
||||
and tool._nested_agent_runs
|
||||
):
|
||||
tool_args["_agent_nested_runs"] = tool._nested_agent_runs
|
||||
# Do not attach search docs to the agent tool call itself; child agents own their docs
|
||||
agent_search_docs = None
|
||||
else:
|
||||
tool_args = tool_call.tool_args
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=current_tool_call_index,
|
||||
@@ -1022,10 +1039,11 @@ def run_llm_loop(
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_arguments=tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
search_docs=search_docs,
|
||||
search_docs=agent_search_docs,
|
||||
generated_images=generated_images,
|
||||
invoked_persona_id=invoked_persona_id,
|
||||
)
|
||||
collected_tool_calls.append(tool_call_info)
|
||||
# Add to state container for partial save support
|
||||
|
||||
@@ -102,6 +102,7 @@ def _create_and_link_tool_calls(
|
||||
if tool_call_info.generated_images
|
||||
else None
|
||||
),
|
||||
invoked_persona_id=tool_call_info.invoked_persona_id,
|
||||
add_only=True,
|
||||
)
|
||||
|
||||
@@ -216,8 +217,9 @@ def save_chat_turn(
|
||||
search_doc_key_to_id[search_doc_key] = db_search_doc.id
|
||||
search_doc_ids_for_tool.append(db_search_doc.id)
|
||||
|
||||
unique_search_doc_ids = list(dict.fromkeys(search_doc_ids_for_tool))
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = (
|
||||
search_doc_ids_for_tool
|
||||
unique_search_doc_ids
|
||||
)
|
||||
|
||||
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
|
||||
|
||||
@@ -101,3 +101,5 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
MAX_AGENT_RECURSION_DEPTH = int(os.environ.get("MAX_AGENT_RECURSION_DEPTH", "3"))
|
||||
|
||||
@@ -2234,9 +2234,14 @@ class ToolCall(Base):
|
||||
generated_images: Mapped[list[dict] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
invoked_persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
invoked_persona: Mapped["Persona | None"] = relationship(
|
||||
"Persona", foreign_keys=[invoked_persona_id]
|
||||
)
|
||||
|
||||
chat_message: Mapped["ChatMessage | None"] = relationship(
|
||||
"ChatMessage",
|
||||
@@ -2714,6 +2719,22 @@ class StarterMessage(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class Persona__Persona(Base):
|
||||
__tablename__ = "persona__persona"
|
||||
|
||||
parent_persona_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
child_persona_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
pass_conversation_context: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
pass_files: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
max_tokens_to_child: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
max_tokens_from_child: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
invocation_instructions: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
|
||||
class Persona__PersonaLabel(Base):
|
||||
__tablename__ = "persona__persona_label"
|
||||
|
||||
@@ -2837,8 +2858,14 @@ class Persona(Base):
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="personas",
|
||||
)
|
||||
child_personas: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
secondary=Persona__Persona.__table__,
|
||||
primaryjoin="Persona.id == Persona__Persona.parent_persona_id",
|
||||
secondaryjoin="Persona.id == Persona__Persona.child_persona_id",
|
||||
foreign_keys="[Persona__Persona.parent_persona_id, Persona__Persona.child_persona_id]",
|
||||
)
|
||||
|
||||
# Default personas loaded via yaml cannot have the same name
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"_builtin_persona_name_idx",
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import PersonaLabel
|
||||
@@ -36,6 +37,7 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import ChildPersonaConfig
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import MinimalPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
@@ -286,7 +288,6 @@ def create_update_persona(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
)
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
persona_id=persona.id,
|
||||
creator_user_id=user.id if user else None,
|
||||
@@ -295,6 +296,14 @@ def create_update_persona(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if create_persona_request.child_persona_ids is not None:
|
||||
update_persona_child_personas(
|
||||
persona_id=persona.id,
|
||||
child_persona_ids=create_persona_request.child_persona_ids,
|
||||
child_persona_configs=create_persona_request.child_persona_configs,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -348,6 +357,86 @@ def update_persona_public_status(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_child_personas(
|
||||
persona_id: int,
|
||||
child_persona_ids: list[int],
|
||||
child_persona_configs: list | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
from onyx.server.features.persona.models import ChildPersonaConfig
|
||||
|
||||
db_session.query(Persona__Persona).filter(
|
||||
Persona__Persona.parent_persona_id == persona_id
|
||||
).delete()
|
||||
|
||||
config_map: dict[int, ChildPersonaConfig] = {}
|
||||
if child_persona_configs:
|
||||
for cfg in child_persona_configs:
|
||||
if isinstance(cfg, dict):
|
||||
cfg = ChildPersonaConfig(**cfg)
|
||||
config_map[cfg.persona_id] = cfg
|
||||
|
||||
for child_id in child_persona_ids:
|
||||
if child_id == persona_id:
|
||||
continue
|
||||
|
||||
child_persona = db_session.query(Persona).filter(Persona.id == child_id).first()
|
||||
if not child_persona or child_persona.deleted:
|
||||
continue
|
||||
|
||||
config = config_map.get(child_id)
|
||||
|
||||
new_relation = Persona__Persona(
|
||||
parent_persona_id=persona_id,
|
||||
child_persona_id=child_id,
|
||||
pass_conversation_context=(
|
||||
config.pass_conversation_context if config else True
|
||||
),
|
||||
pass_files=config.pass_files if config else False,
|
||||
max_tokens_to_child=config.max_tokens_to_child if config else None,
|
||||
max_tokens_from_child=config.max_tokens_from_child if config else None,
|
||||
invocation_instructions=config.invocation_instructions if config else None,
|
||||
)
|
||||
db_session.add(new_relation)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_child_personas(
|
||||
persona_id: int,
|
||||
db_session: Session,
|
||||
) -> list[Persona]:
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.join(Persona__Persona, Persona.id == Persona__Persona.child_persona_id)
|
||||
.where(Persona__Persona.parent_persona_id == persona_id)
|
||||
.where(Persona.deleted.is_(False))
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_child_persona_configs(
|
||||
persona_id: int,
|
||||
db_session: Session,
|
||||
) -> list[ChildPersonaConfig]:
|
||||
"""Get the configuration for all child personas of a given persona."""
|
||||
stmt = select(Persona__Persona).where(
|
||||
Persona__Persona.parent_persona_id == persona_id
|
||||
)
|
||||
links = db_session.scalars(stmt).all()
|
||||
return [
|
||||
ChildPersonaConfig(
|
||||
persona_id=link.child_persona_id,
|
||||
pass_conversation_context=link.pass_conversation_context,
|
||||
pass_files=link.pass_files,
|
||||
max_tokens_to_child=link.max_tokens_to_child,
|
||||
max_tokens_from_child=link.max_tokens_from_child,
|
||||
invocation_instructions=link.invocation_instructions,
|
||||
)
|
||||
for link in links
|
||||
]
|
||||
|
||||
|
||||
def _build_persona_filters(
|
||||
stmt: Select[tuple[Persona]],
|
||||
include_default: bool,
|
||||
|
||||
@@ -221,6 +221,7 @@ def create_tool_call_no_commit(
|
||||
parent_tool_call_id: int | None = None,
|
||||
reasoning_tokens: str | None = None,
|
||||
generated_images: list[dict] | None = None,
|
||||
invoked_persona_id: int | None = None,
|
||||
add_only: bool = True,
|
||||
) -> ToolCall:
|
||||
"""
|
||||
@@ -256,6 +257,7 @@ def create_tool_call_no_commit(
|
||||
tool_call_response=tool_call_response,
|
||||
tool_call_tokens=tool_call_tokens,
|
||||
generated_images=generated_images,
|
||||
invoked_persona_id=invoked_persona_id,
|
||||
)
|
||||
|
||||
db_session.add(tool_call)
|
||||
|
||||
@@ -79,6 +79,38 @@ class LLMRateLimitError(Exception):
|
||||
"""
|
||||
|
||||
|
||||
def _convert_tools_to_responses_api_format(tools: list[dict]) -> list[dict]:
|
||||
"""Convert tools from Chat Completions API format to Responses API format.
|
||||
|
||||
Chat Completions API format:
|
||||
{"type": "function", "function": {"name": "...", "description": "...", "parameters": {...}}}
|
||||
|
||||
Responses API format:
|
||||
{"type": "function", "name": "...", "description": "...", "parameters": {...}}
|
||||
"""
|
||||
converted_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function" and "function" in tool:
|
||||
func = tool["function"]
|
||||
name = func.get("name")
|
||||
if not name:
|
||||
logger.warning("Skipping tool with missing name in function definition")
|
||||
continue
|
||||
converted_tool = {
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
}
|
||||
if "strict" in func:
|
||||
converted_tool["strict"] = func["strict"]
|
||||
converted_tools.append(converted_tool)
|
||||
else:
|
||||
# If already in correct format or unknown format, pass through
|
||||
converted_tools.append(tool)
|
||||
return converted_tools
|
||||
|
||||
|
||||
def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
|
||||
return "user"
|
||||
@@ -481,14 +513,22 @@ class LitellmLLM(LLM):
|
||||
)
|
||||
|
||||
# Needed to get reasoning tokens from the model
|
||||
if not is_legacy_langchain and (
|
||||
use_responses_api = not is_legacy_langchain and (
|
||||
is_true_openai_model(self.config.model_provider, self.config.model_name)
|
||||
or self.config.model_provider == AZURE_PROVIDER_NAME
|
||||
):
|
||||
)
|
||||
if use_responses_api:
|
||||
model_provider = f"{self.config.model_provider}/responses"
|
||||
else:
|
||||
model_provider = self.config.model_provider
|
||||
|
||||
# Convert tools to Responses API format if using that API
|
||||
processed_tools = (
|
||||
_convert_tools_to_responses_api_format(tools)
|
||||
if use_responses_api and tools
|
||||
else tools
|
||||
)
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
@@ -503,14 +543,19 @@ class LitellmLLM(LLM):
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
# actual input
|
||||
messages=processed_prompt,
|
||||
tools=tools,
|
||||
tools=processed_tools,
|
||||
tool_choice=tool_choice_formatted,
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=(1 if is_reasoning else self._temperature),
|
||||
timeout=timeout_override or self._timeout,
|
||||
**({"stream_options": {"include_usage": True}} if stream else {}),
|
||||
# stream_options is not supported by the Responses API
|
||||
**(
|
||||
{"stream_options": {"include_usage": True}}
|
||||
if stream and not use_responses_api
|
||||
else {}
|
||||
),
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
|
||||
@@ -478,6 +478,8 @@ def get_persona(
|
||||
user: User | None = Depends(current_limited_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullPersonaSnapshot:
|
||||
from onyx.db.persona import get_child_persona_configs
|
||||
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
@@ -496,7 +498,8 @@ def get_persona(
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
child_configs = get_child_persona_configs(persona_id, db_session)
|
||||
return FullPersonaSnapshot.from_model(persona, child_persona_configs=child_configs)
|
||||
|
||||
|
||||
@basic_router.post("/assistant-prompt-refresh")
|
||||
|
||||
@@ -52,6 +52,15 @@ class GenerateStarterMessageRequest(BaseModel):
|
||||
generation_count: int
|
||||
|
||||
|
||||
class ChildPersonaConfig(BaseModel):
|
||||
persona_id: int
|
||||
pass_conversation_context: bool = True
|
||||
pass_files: bool = False
|
||||
max_tokens_to_child: int | None = None
|
||||
max_tokens_from_child: int | None = None
|
||||
invocation_instructions: str | None = None
|
||||
|
||||
|
||||
class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
@@ -64,24 +73,20 @@ class PersonaUpsertRequest(BaseModel):
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
# e.g. ID of SearchTool or ImageGenerationTool or <USER_DEFINED_TOOL>
|
||||
tool_ids: list[int]
|
||||
remove_image: bool | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
icon_name: str | None = (
|
||||
None # New field that is custom chosen during agent creation/editing
|
||||
)
|
||||
uploaded_image_id: str | None = None
|
||||
icon_name: str | None = None
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
child_persona_ids: list[int] = Field(default_factory=list)
|
||||
child_persona_configs: list[ChildPersonaConfig] | None = None
|
||||
|
||||
# prompt fields
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
@@ -158,6 +163,14 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ChildPersonaSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
uploaded_image_id: str | None = None
|
||||
icon_name: str | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@@ -166,7 +179,6 @@ class PersonaSnapshot(BaseModel):
|
||||
is_visible: bool
|
||||
uploaded_image_id: str | None
|
||||
icon_name: str | None
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
is_default_persona: bool
|
||||
@@ -183,14 +195,33 @@ class PersonaSnapshot(BaseModel):
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
child_personas: list[ChildPersonaSnapshot] = []
|
||||
child_persona_configs: list[ChildPersonaConfig] = []
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
task_prompt: str | None = None
|
||||
datetime_aware: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
|
||||
def from_model(
|
||||
cls,
|
||||
persona: Persona,
|
||||
allow_deleted: bool = False,
|
||||
child_persona_configs: list[ChildPersonaConfig] | None = None,
|
||||
) -> "PersonaSnapshot":
|
||||
child_persona_list = []
|
||||
if hasattr(persona, "child_personas") and persona.child_personas:
|
||||
child_persona_list = [
|
||||
ChildPersonaSnapshot(
|
||||
id=cp.id,
|
||||
name=cp.name,
|
||||
description=cp.description,
|
||||
uploaded_image_id=cp.uploaded_image_id,
|
||||
icon_name=cp.icon_name,
|
||||
)
|
||||
for cp in persona.child_personas
|
||||
if not cp.deleted
|
||||
]
|
||||
return PersonaSnapshot(
|
||||
id=persona.id,
|
||||
name=persona.name,
|
||||
@@ -229,6 +260,8 @@ class PersonaSnapshot(BaseModel):
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
child_personas=child_persona_list,
|
||||
child_persona_configs=child_persona_configs or [],
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
@@ -244,7 +277,10 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, persona: Persona, allow_deleted: bool = False
|
||||
cls,
|
||||
persona: Persona,
|
||||
allow_deleted: bool = False,
|
||||
child_persona_configs: list[ChildPersonaConfig] | None = None,
|
||||
) -> "FullPersonaSnapshot":
|
||||
if persona.deleted:
|
||||
error_msg = f"Persona with ID {persona.id} has been deleted"
|
||||
@@ -253,6 +289,20 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
else:
|
||||
logger.warning(error_msg)
|
||||
|
||||
child_persona_list = []
|
||||
if hasattr(persona, "child_personas") and persona.child_personas:
|
||||
child_persona_list = [
|
||||
ChildPersonaSnapshot(
|
||||
id=cp.id,
|
||||
name=cp.name,
|
||||
description=cp.description,
|
||||
uploaded_image_id=cp.uploaded_image_id,
|
||||
icon_name=cp.icon_name,
|
||||
)
|
||||
for cp in persona.child_personas
|
||||
if not cp.deleted
|
||||
]
|
||||
|
||||
return FullPersonaSnapshot(
|
||||
id=persona.id,
|
||||
name=persona.name,
|
||||
@@ -292,6 +342,8 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
child_personas=child_persona_list,
|
||||
child_persona_configs=child_persona_configs or [],
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
|
||||
@@ -10,9 +10,12 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import get_db_search_doc_by_id
|
||||
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import AgentToolFinal
|
||||
from onyx.server.query_and_chat.streaming_models import AgentToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
@@ -260,6 +263,213 @@ def create_search_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def collect_nested_search_data(
|
||||
tool_call: "ToolCall",
|
||||
db_session: Session,
|
||||
include_children: bool = True,
|
||||
) -> tuple[list[str], list[SavedSearchDoc]]:
|
||||
queries: list[str] = []
|
||||
docs: list[SavedSearchDoc] = []
|
||||
|
||||
if "_agent_search_queries" in tool_call.tool_call_arguments:
|
||||
queries.extend(
|
||||
cast(list[str], tool_call.tool_call_arguments["_agent_search_queries"])
|
||||
)
|
||||
|
||||
if tool_call.search_docs:
|
||||
docs.extend(
|
||||
[
|
||||
translate_db_search_doc_to_saved_search_doc(doc)
|
||||
for doc in tool_call.search_docs
|
||||
]
|
||||
)
|
||||
|
||||
if include_children:
|
||||
for child_tool_call in tool_call.tool_call_children:
|
||||
if child_tool_call.invoked_persona_id is not None:
|
||||
child_queries, child_docs = collect_nested_search_data(
|
||||
child_tool_call, db_session, include_children=True
|
||||
)
|
||||
queries.extend(child_queries)
|
||||
docs.extend(child_docs)
|
||||
else:
|
||||
if child_tool_call.tool_call_arguments.get("queries"):
|
||||
child_queries = cast(
|
||||
list[str], child_tool_call.tool_call_arguments["queries"]
|
||||
)
|
||||
queries.extend(child_queries)
|
||||
if child_tool_call.search_docs:
|
||||
docs.extend(
|
||||
[
|
||||
translate_db_search_doc_to_saved_search_doc(doc)
|
||||
for doc in child_tool_call.search_docs
|
||||
]
|
||||
)
|
||||
|
||||
unique_queries = list(dict.fromkeys(queries))
|
||||
|
||||
seen_doc_ids: set[str] = set()
|
||||
unique_docs: list[SavedSearchDoc] = []
|
||||
for doc in docs:
|
||||
if doc.document_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc.document_id)
|
||||
unique_docs.append(doc)
|
||||
|
||||
return unique_queries, unique_docs
|
||||
|
||||
|
||||
def create_agent_tool_packets(
|
||||
agent_name: str,
|
||||
agent_id: int,
|
||||
response: str,
|
||||
turn_index: int,
|
||||
search_queries: list[str] | None = None,
|
||||
search_docs: list[SavedSearchDoc] | None = None,
|
||||
nested_runs: list[dict] | None = None,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolStart(agent_name=agent_name, agent_id=agent_id),
|
||||
)
|
||||
)
|
||||
|
||||
# Emit SearchToolStart if we have search queries or docs (needed for frontend to show search section)
|
||||
if search_queries or search_docs:
|
||||
packets.append(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolStart(is_internet_search=False),
|
||||
)
|
||||
)
|
||||
|
||||
if search_queries:
|
||||
packets.append(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolQueriesDelta(queries=search_queries),
|
||||
)
|
||||
)
|
||||
|
||||
if search_docs:
|
||||
sorted_search_docs = sorted(
|
||||
search_docs, key=lambda x: x.score or 0.0, reverse=True
|
||||
)
|
||||
packets.append(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=SearchToolDocumentsDelta(
|
||||
documents=[
|
||||
SearchDoc(**doc.model_dump()) for doc in sorted_search_docs
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
summary = response[:200] + "..." if len(response) > 200 else response
|
||||
packets.append(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolFinal(
|
||||
agent_name=agent_name,
|
||||
summary=summary,
|
||||
full_response=response,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
|
||||
|
||||
if nested_runs:
|
||||
for nested in nested_runs:
|
||||
nested_agent_name = nested.get("agent_name", "")
|
||||
nested_agent_id = nested.get("agent_id", 0)
|
||||
nested_response = nested.get("response", "")
|
||||
nested_queries = nested.get("search_queries") or []
|
||||
nested_docs_raw = nested.get("search_docs") or []
|
||||
nested_nested_runs = nested.get("nested_runs") or None
|
||||
|
||||
nested_docs: list[SavedSearchDoc] = []
|
||||
for doc_dict in nested_docs_raw:
|
||||
try:
|
||||
nested_docs.append(SavedSearchDoc(**doc_dict))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
packets.extend(
|
||||
create_agent_tool_packets(
|
||||
agent_name=nested_agent_name,
|
||||
agent_id=nested_agent_id,
|
||||
response=nested_response,
|
||||
turn_index=turn_index,
|
||||
search_queries=nested_queries if nested_queries else None,
|
||||
search_docs=nested_docs if nested_docs else None,
|
||||
nested_runs=nested_nested_runs,
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def reconstruct_nested_agent_tool_call(
|
||||
tool_call: "ToolCall",
|
||||
base_turn_index: int,
|
||||
agent_counter: list[int],
|
||||
db_session: Session,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
if tool_call.invoked_persona_id is None:
|
||||
return packets
|
||||
|
||||
invoked_persona = tool_call.invoked_persona
|
||||
if not invoked_persona:
|
||||
return packets
|
||||
|
||||
# Assign this agent the next sequential turn_index
|
||||
current_turn_index = base_turn_index + agent_counter[0]
|
||||
agent_counter[0] += 1
|
||||
|
||||
# Collect search queries and docs from this agent and its nested children
|
||||
nested_queries, nested_search_docs = collect_nested_search_data(
|
||||
tool_call, db_session, include_children=False
|
||||
)
|
||||
nested_runs = (
|
||||
cast(list[dict] | None, tool_call.tool_call_arguments.get("_agent_nested_runs"))
|
||||
if tool_call.tool_call_arguments
|
||||
else None
|
||||
)
|
||||
|
||||
# Create packets for this agent
|
||||
packets.extend(
|
||||
create_agent_tool_packets(
|
||||
agent_name=invoked_persona.name,
|
||||
agent_id=invoked_persona.id,
|
||||
response=tool_call.tool_call_response,
|
||||
turn_index=current_turn_index,
|
||||
search_queries=nested_queries if nested_queries else None,
|
||||
search_docs=nested_search_docs if nested_search_docs else None,
|
||||
nested_runs=nested_runs,
|
||||
)
|
||||
)
|
||||
|
||||
# Recursively process nested agent tool calls (they'll get the next sequential indices)
|
||||
for child_tool_call in tool_call.tool_call_children:
|
||||
if child_tool_call.invoked_persona_id is not None:
|
||||
# This is a nested agent tool call - recursively process it
|
||||
child_packets = reconstruct_nested_agent_tool_call(
|
||||
child_tool_call,
|
||||
base_turn_index,
|
||||
agent_counter,
|
||||
db_session,
|
||||
)
|
||||
packets.extend(child_packets)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def translate_assistant_message_to_packets(
|
||||
chat_message: ChatMessage,
|
||||
db_session: Session,
|
||||
@@ -275,9 +485,14 @@ def translate_assistant_message_to_packets(
|
||||
raise ValueError(f"Chat message {chat_message.id} is not an assistant message")
|
||||
|
||||
if chat_message.tool_calls:
|
||||
# Group tool calls by turn_number
|
||||
# Filter to only top-level tool calls (parent_tool_call_id is None)
|
||||
top_level_tool_calls = [
|
||||
tc for tc in chat_message.tool_calls if tc.parent_tool_call_id is None
|
||||
]
|
||||
|
||||
# Group top-level tool calls by turn_number
|
||||
tool_calls_by_turn: dict[int, list] = {}
|
||||
for tool_call in chat_message.tool_calls:
|
||||
for tool_call in top_level_tool_calls:
|
||||
turn_num = tool_call.turn_number
|
||||
if turn_num not in tool_calls_by_turn:
|
||||
tool_calls_by_turn[turn_num] = []
|
||||
@@ -287,9 +502,25 @@ def translate_assistant_message_to_packets(
|
||||
for turn_num in sorted(tool_calls_by_turn.keys()):
|
||||
tool_calls_in_turn = tool_calls_by_turn[turn_num]
|
||||
|
||||
# Use a counter to assign sequential turn_index values for nested agents
|
||||
# This ensures proper ordering: parent agents get lower indices, nested agents get higher indices
|
||||
agent_counter = [0]
|
||||
|
||||
# Process each tool call in this turn
|
||||
for tool_call in tool_calls_in_turn:
|
||||
try:
|
||||
# Handle agent tools specially - they have invoked_persona_id set
|
||||
if tool_call.invoked_persona_id is not None:
|
||||
# Recursively reconstruct this agent and all nested agents
|
||||
agent_packets = reconstruct_nested_agent_tool_call(
|
||||
tool_call,
|
||||
turn_num,
|
||||
agent_counter,
|
||||
db_session,
|
||||
)
|
||||
packet_list.extend(agent_packets)
|
||||
continue
|
||||
|
||||
tool = get_tool_by_id(tool_call.tool_id, db_session)
|
||||
|
||||
# Handle different tool types
|
||||
|
||||
@@ -10,8 +10,6 @@ from onyx.context.search.models import SearchDoc
|
||||
|
||||
|
||||
class StreamingType(Enum):
|
||||
"""Enum defining all streaming packet types. This is the single source of truth for type strings."""
|
||||
|
||||
MESSAGE_START = "message_start"
|
||||
MESSAGE_DELTA = "message_delta"
|
||||
ERROR = "error"
|
||||
@@ -33,6 +31,9 @@ class StreamingType(Enum):
|
||||
REASONING_DELTA = "reasoning_delta"
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
AGENT_TOOL_START = "agent_tool_start"
|
||||
AGENT_TOOL_DELTA = "agent_tool_delta"
|
||||
AGENT_TOOL_FINAL = "agent_tool_final"
|
||||
|
||||
|
||||
class BaseObj(BaseModel):
|
||||
@@ -210,31 +211,46 @@ class CustomToolStart(BaseObj):
|
||||
tool_name: str
|
||||
|
||||
|
||||
# The allowed streamed packets for a custom tool
|
||||
class CustomToolDelta(BaseObj):
|
||||
type: Literal["custom_tool_delta"] = StreamingType.CUSTOM_TOOL_DELTA.value
|
||||
|
||||
tool_name: str
|
||||
response_type: str
|
||||
# For non-file responses
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
# For file-based responses like image/csv
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AgentToolStart(BaseObj):
|
||||
type: Literal["agent_tool_start"] = StreamingType.AGENT_TOOL_START.value
|
||||
|
||||
agent_name: str
|
||||
agent_id: int
|
||||
|
||||
|
||||
class AgentToolDelta(BaseObj):
|
||||
type: Literal["agent_tool_delta"] = StreamingType.AGENT_TOOL_DELTA.value
|
||||
|
||||
agent_name: str
|
||||
status_text: str | None = None
|
||||
nested_content: str | None = None
|
||||
|
||||
|
||||
class AgentToolFinal(BaseObj):
|
||||
type: Literal["agent_tool_final"] = StreamingType.AGENT_TOOL_FINAL.value
|
||||
|
||||
agent_name: str
|
||||
summary: str
|
||||
full_response: str | None = None
|
||||
|
||||
|
||||
"""Packet"""
|
||||
|
||||
# Discriminated union of all possible packet object types
|
||||
PacketObj = Union[
|
||||
# Agent Response Packets
|
||||
AgentResponseStart,
|
||||
AgentResponseDelta,
|
||||
# Control Packets
|
||||
OverallStop,
|
||||
SectionEnd,
|
||||
# Error Packets
|
||||
PacketException,
|
||||
# Tool Packets
|
||||
SearchToolStart,
|
||||
SearchToolQueriesDelta,
|
||||
SearchToolDocumentsDelta,
|
||||
@@ -248,11 +264,12 @@ PacketObj = Union[
|
||||
PythonToolDelta,
|
||||
CustomToolStart,
|
||||
CustomToolDelta,
|
||||
# Reasoning Packets
|
||||
AgentToolStart,
|
||||
AgentToolDelta,
|
||||
AgentToolFinal,
|
||||
ReasoningStart,
|
||||
ReasoningDelta,
|
||||
ReasoningDone,
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
]
|
||||
|
||||
|
||||
@@ -163,6 +163,7 @@ class ToolCallInfo(BaseModel):
|
||||
tool_call_response: str
|
||||
search_docs: list[SearchDoc] | None = None
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
invoked_persona_id: int | None = None # For agent tools
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
|
||||
@@ -30,6 +30,9 @@ from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from onyx.tools.models import DynamicSchemaInfo
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.agent.agent_tool import AgentTool
|
||||
from onyx.tools.tool_implementations.agent.agent_tool import generate_agent_tool_id
|
||||
from onyx.tools.tool_implementations.agent.models import AgentInvocationConfig
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
@@ -371,6 +374,52 @@ def construct_tools(
|
||||
f"Tool '{expected_tool_name}' not found in MCP server '{mcp_server.name}'"
|
||||
)
|
||||
|
||||
if hasattr(persona, "child_personas") and persona.child_personas:
|
||||
from onyx.db.models import Persona__Persona
|
||||
from sqlalchemy import select
|
||||
|
||||
for child_persona in persona.child_personas:
|
||||
if child_persona.deleted:
|
||||
continue
|
||||
|
||||
config_stmt = select(Persona__Persona).where(
|
||||
Persona__Persona.parent_persona_id == persona.id,
|
||||
Persona__Persona.child_persona_id == child_persona.id,
|
||||
)
|
||||
config_row = db_session.execute(config_stmt).scalar_one_or_none()
|
||||
|
||||
agent_config = AgentInvocationConfig(
|
||||
pass_conversation_context=(
|
||||
config_row.pass_conversation_context if config_row else True
|
||||
),
|
||||
pass_files=config_row.pass_files if config_row else False,
|
||||
max_tokens_to_child=(
|
||||
config_row.max_tokens_to_child if config_row else None
|
||||
),
|
||||
max_tokens_from_child=(
|
||||
config_row.max_tokens_from_child if config_row else None
|
||||
),
|
||||
invocation_instructions=(
|
||||
config_row.invocation_instructions if config_row else None
|
||||
),
|
||||
)
|
||||
|
||||
agent_tool_id = generate_agent_tool_id(persona.id, child_persona.id)
|
||||
|
||||
agent_tool = AgentTool(
|
||||
tool_id=agent_tool_id,
|
||||
emitter=emitter,
|
||||
child_persona=child_persona,
|
||||
parent_persona_id=persona.id,
|
||||
agent_config=agent_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
)
|
||||
|
||||
tool_dict[agent_tool_id] = [agent_tool]
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from onyx.tools.tool_implementations.agent.agent_tool import AgentTool
|
||||
|
||||
__all__ = ["AgentTool"]
|
||||
571
backend/onyx/tools/tool_implementations/agent/agent_tool.py
Normal file
571
backend/onyx/tools/tool_implementations/agent/agent_tool.py
Normal file
@@ -0,0 +1,571 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.chat_configs import MAX_AGENT_RECURSION_DEPTH
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.server.query_and_chat.streaming_models import AgentToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentToolFinal
|
||||
from onyx.server.query_and_chat.streaming_models import AgentToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.agent.models import AgentInvocationConfig
|
||||
from onyx.tools.tool_implementations.agent.models import AgentToolOverrideKwargs
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
AGENT_TOOL_ID_OFFSET = 1000000
|
||||
|
||||
|
||||
def generate_agent_tool_id(parent_persona_id: int, child_persona_id: int) -> int:
|
||||
return AGENT_TOOL_ID_OFFSET + (parent_persona_id * 10000) + child_persona_id
|
||||
|
||||
|
||||
class AgentTool(Tool[AgentToolOverrideKwargs]):
|
||||
NAME_PREFIX = "invoke_agent_"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_id: int,
|
||||
emitter: Emitter,
|
||||
child_persona: Persona,
|
||||
parent_persona_id: int,
|
||||
agent_config: AgentInvocationConfig,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
):
|
||||
super().__init__(emitter)
|
||||
self._tool_id = tool_id
|
||||
self._child_persona = child_persona
|
||||
self._parent_persona_id = parent_persona_id
|
||||
self._config = agent_config
|
||||
self._db_session = db_session
|
||||
self._user = user
|
||||
self._llm = llm
|
||||
self._fast_llm = fast_llm
|
||||
self._collected_search_queries: list[str] = []
|
||||
self._collected_search_docs: list = []
|
||||
self._nested_agent_runs: list[dict] = []
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._tool_id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
safe_name = self._child_persona.name.lower().replace(" ", "_").replace("-", "_")
|
||||
return f"{self.NAME_PREFIX}{safe_name}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
base_desc = f"Invoke the '{self._child_persona.name}' agent to help with tasks."
|
||||
if self._child_persona.description:
|
||||
base_desc += (
|
||||
f" This agent specializes in: {self._child_persona.description}"
|
||||
)
|
||||
if self._config.invocation_instructions:
|
||||
base_desc += f" Use this agent when: {self._config.invocation_instructions}"
|
||||
return base_desc
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return f"Agent: {self._child_persona.name}"
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The specific task or question to delegate to this agent",
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": "Additional context or background information for the agent (optional)",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def emit_start(self, turn_index: int) -> None:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolStart(
|
||||
agent_name=self._child_persona.name,
|
||||
agent_id=self._child_persona.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def _get_tool_override_kwargs(self, tool: Tool, task: str) -> Any:
|
||||
from onyx.tools.models import (
|
||||
ChatMinimalTextMessage,
|
||||
OpenURLToolOverrideKwargs,
|
||||
SearchToolOverrideKwargs,
|
||||
WebSearchToolOverrideKwargs,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.configs.constants import MessageType
|
||||
|
||||
if isinstance(tool, OpenURLTool):
|
||||
return OpenURLToolOverrideKwargs(
|
||||
starting_citation_num=1,
|
||||
citation_mapping={},
|
||||
)
|
||||
elif isinstance(tool, SearchTool):
|
||||
minimal_history = [
|
||||
ChatMinimalTextMessage(
|
||||
message=task,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
]
|
||||
return SearchToolOverrideKwargs(
|
||||
starting_citation_num=1,
|
||||
original_query=task,
|
||||
message_history=minimal_history,
|
||||
)
|
||||
elif isinstance(tool, WebSearchTool):
|
||||
return WebSearchToolOverrideKwargs(
|
||||
starting_citation_num=1,
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_child_tools(self) -> list[Tool]:
|
||||
"""Build the child agent's tools (excluding AgentTools to prevent infinite recursion)."""
|
||||
from onyx.tools.tool_constructor import construct_tools, SearchToolConfig
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=self._child_persona,
|
||||
db_session=self._db_session,
|
||||
emitter=self.emitter,
|
||||
user=self._user,
|
||||
llm=self._llm,
|
||||
fast_llm=self._fast_llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
disable_internal_search=False,
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
for tool in tool_list:
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
def _run_child_agent_loop(
|
||||
self,
|
||||
task: str,
|
||||
context: str | None,
|
||||
turn_index: int,
|
||||
max_iterations: int = 5,
|
||||
) -> str:
|
||||
"""Run a simplified agent loop for the child agent with its tools."""
|
||||
from onyx.llm.message_types import (
|
||||
AssistantMessage,
|
||||
FunctionCall,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
UserMessageWithText,
|
||||
)
|
||||
|
||||
child_tools = self._build_child_tools()
|
||||
tool_definitions = [tool.tool_definition() for tool in child_tools]
|
||||
tool_name_to_tool = {tool.name: tool for tool in child_tools}
|
||||
|
||||
base_system_prompt = (
|
||||
self._child_persona.system_prompt or "You are a helpful assistant."
|
||||
)
|
||||
|
||||
if child_tools:
|
||||
tool_names = [t.display_name for t in child_tools]
|
||||
tools_instruction = (
|
||||
f"\n\nYou have access to the following tools: {', '.join(tool_names)}. "
|
||||
"CRITICAL: You MUST use your tools to look up specific, accurate information before answering. "
|
||||
"Do NOT make up information, use placeholder text like '[specific value]', or provide generic responses. "
|
||||
"Always use your tools first to gather real data, then provide a response based on that data."
|
||||
)
|
||||
system_prompt = base_system_prompt + tools_instruction
|
||||
else:
|
||||
system_prompt = base_system_prompt
|
||||
|
||||
full_prompt = f"Task: {task}"
|
||||
if context:
|
||||
full_prompt = f"Context: {context}\n\n{full_prompt}"
|
||||
|
||||
if child_tools:
|
||||
full_prompt += (
|
||||
"\n\nREMINDER: Use your available tools to find real, specific information. "
|
||||
"Do not provide a response with placeholders or generic text."
|
||||
)
|
||||
|
||||
messages: list = [
|
||||
SystemMessage(role="system", content=system_prompt),
|
||||
UserMessageWithText(role="user", content=full_prompt),
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Running child agent '{self._child_persona.name}' with {len(child_tools)} tools: "
|
||||
f"{[t.name for t in child_tools]}"
|
||||
)
|
||||
|
||||
final_response = ""
|
||||
collected_search_queries: list[str] = []
|
||||
collected_search_docs: list = []
|
||||
self._nested_agent_runs = []
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
logger.debug(f"Child agent iteration {iteration + 1}/{max_iterations}")
|
||||
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolDelta(
|
||||
agent_name=self._child_persona.name,
|
||||
status_text=f"Thinking... (step {iteration + 1})",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_definitions:
|
||||
current_tool_choice = "required" if iteration == 0 else "auto"
|
||||
logger.debug(
|
||||
f"Child agent iteration {iteration + 1}: tool_choice={current_tool_choice}"
|
||||
)
|
||||
response = self._llm.invoke(
|
||||
prompt=messages,
|
||||
tools=tool_definitions,
|
||||
tool_choice=current_tool_choice,
|
||||
)
|
||||
else:
|
||||
response = self._llm.invoke(prompt=messages)
|
||||
|
||||
assistant_content = response.choice.message.content
|
||||
tool_calls = response.choice.message.tool_calls
|
||||
|
||||
if tool_calls:
|
||||
tool_call_entries: list[ToolCall] = []
|
||||
for tc in tool_calls:
|
||||
fn_name: str = tc.function.name if tc.function.name else ""
|
||||
fn_args: str = (
|
||||
tc.function.arguments if tc.function.arguments else "{}"
|
||||
)
|
||||
fn_call: FunctionCall = {"name": fn_name, "arguments": fn_args}
|
||||
tc_entry: ToolCall = {
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": fn_call,
|
||||
}
|
||||
tool_call_entries.append(tc_entry)
|
||||
|
||||
assistant_msg: AssistantMessage = cast(
|
||||
AssistantMessage,
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
"tool_calls": tool_call_entries,
|
||||
},
|
||||
)
|
||||
messages.append(assistant_msg)
|
||||
|
||||
for tool_call in tool_calls:
|
||||
tool_name: str = (
|
||||
tool_call.function.name if tool_call.function.name else ""
|
||||
)
|
||||
tool = tool_name_to_tool.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
logger.warning(
|
||||
f"Tool {tool_name} not found for child agent. "
|
||||
f"Available tools: {list(tool_name_to_tool.keys())}"
|
||||
)
|
||||
tool_result = f"Error: Tool {tool_name} not found"
|
||||
else:
|
||||
logger.info(
|
||||
f"Child agent '{self._child_persona.name}' calling tool: {tool_name}"
|
||||
)
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolDelta(
|
||||
agent_name=self._child_persona.name,
|
||||
status_text=f"Using {tool.display_name}...",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
fn_arguments: str = (
|
||||
tool_call.function.arguments
|
||||
if tool_call.function.arguments
|
||||
else "{}"
|
||||
)
|
||||
args = json.loads(fn_arguments)
|
||||
logger.info(f"Tool {tool_name} called with args: {args}")
|
||||
|
||||
# Emit the tool start packet (normally done by tool_runner)
|
||||
tool.emit_start(turn_index=turn_index)
|
||||
|
||||
tool_override_kwargs = self._get_tool_override_kwargs(
|
||||
tool, task
|
||||
)
|
||||
tool_response = tool.run(
|
||||
turn_index=turn_index,
|
||||
override_kwargs=tool_override_kwargs,
|
||||
**args,
|
||||
)
|
||||
tool_result = tool_response.llm_facing_response
|
||||
|
||||
# If this is a nested agent tool, capture its run (response + search data + deeper agents)
|
||||
if isinstance(tool, AgentTool):
|
||||
nested_queries = (
|
||||
getattr(tool, "_collected_search_queries", []) or []
|
||||
)
|
||||
nested_docs = (
|
||||
getattr(tool, "_collected_search_docs", []) or []
|
||||
)
|
||||
nested_doc_dicts = []
|
||||
for doc in nested_docs:
|
||||
try:
|
||||
if hasattr(doc, "db_doc_id"):
|
||||
doc_dump = doc.model_dump()
|
||||
if doc_dump.get("db_doc_id") is None:
|
||||
doc_dump["db_doc_id"] = 0
|
||||
else:
|
||||
doc_dump = SavedSearchDoc.from_search_doc(
|
||||
doc, db_doc_id=0
|
||||
).model_dump()
|
||||
nested_doc_dicts.append(doc_dump)
|
||||
except Exception:
|
||||
continue
|
||||
nested_entry: dict = {
|
||||
"agent_name": tool._child_persona.name,
|
||||
"agent_id": tool._child_persona.id,
|
||||
"response": tool_result,
|
||||
"search_queries": nested_queries,
|
||||
"search_docs": nested_doc_dicts,
|
||||
}
|
||||
if getattr(tool, "_nested_agent_runs", None):
|
||||
nested_entry["nested_runs"] = (
|
||||
tool._nested_agent_runs
|
||||
)
|
||||
self._nested_agent_runs.append(nested_entry)
|
||||
|
||||
# Bubble up nested agent search data
|
||||
if nested_queries:
|
||||
existing_queries = set(collected_search_queries)
|
||||
for query in nested_queries:
|
||||
if query not in existing_queries:
|
||||
collected_search_queries.append(query)
|
||||
existing_queries.add(query)
|
||||
if nested_docs:
|
||||
existing_doc_ids = {
|
||||
doc.document_id for doc in collected_search_docs
|
||||
}
|
||||
for doc in nested_docs:
|
||||
if doc.document_id not in existing_doc_ids:
|
||||
collected_search_docs.append(doc)
|
||||
existing_doc_ids.add(doc.document_id)
|
||||
|
||||
# Also collect from search tools directly
|
||||
from onyx.tools.models import SearchDocsResponse
|
||||
|
||||
if isinstance(
|
||||
tool_response.rich_response, SearchDocsResponse
|
||||
):
|
||||
if tool_response.rich_response.search_docs:
|
||||
existing_doc_ids = {
|
||||
doc.document_id for doc in collected_search_docs
|
||||
}
|
||||
for doc in tool_response.rich_response.search_docs:
|
||||
if doc.document_id not in existing_doc_ids:
|
||||
collected_search_docs.append(doc)
|
||||
existing_doc_ids.add(doc.document_id)
|
||||
# Extract queries from args if it's a search tool
|
||||
if "queries" in args:
|
||||
queries = args["queries"]
|
||||
if isinstance(queries, list):
|
||||
existing_queries = set(collected_search_queries)
|
||||
for query in queries:
|
||||
if query not in existing_queries:
|
||||
collected_search_queries.append(query)
|
||||
existing_queries.add(query)
|
||||
|
||||
logger.info(
|
||||
f"Tool {tool_name} returned {len(tool_result)} chars: "
|
||||
f"{tool_result[:200]}..."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error running tool {tool_name}")
|
||||
tool_result = f"Error running tool: {str(e)}"
|
||||
|
||||
tool_msg: ToolMessage = {
|
||||
"role": "tool",
|
||||
"content": tool_result,
|
||||
"tool_call_id": tool_call.id,
|
||||
}
|
||||
messages.append(tool_msg)
|
||||
|
||||
else:
|
||||
if assistant_content:
|
||||
final_response = assistant_content
|
||||
logger.info(
|
||||
f"Child agent '{self._child_persona.name}' final response "
|
||||
f"({len(assistant_content)} chars): {assistant_content[:300]}..."
|
||||
)
|
||||
for chunk in [
|
||||
assistant_content[i : i + 50]
|
||||
for i in range(0, len(assistant_content), 50)
|
||||
]:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolDelta(
|
||||
agent_name=self._child_persona.name,
|
||||
nested_content=chunk,
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Child agent '{self._child_persona.name}' produced no response content"
|
||||
)
|
||||
break
|
||||
|
||||
if not final_response:
|
||||
logger.warning(
|
||||
f"Child agent '{self._child_persona.name}' loop ended without final response "
|
||||
f"after {max_iterations} iterations"
|
||||
)
|
||||
|
||||
# Store collected data for persistence
|
||||
self._collected_search_queries = collected_search_queries
|
||||
self._collected_search_docs = collected_search_docs
|
||||
|
||||
return final_response
|
||||
|
||||
def run(
|
||||
self,
|
||||
turn_index: int,
|
||||
override_kwargs: AgentToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
task = llm_kwargs.get("task", "")
|
||||
context = llm_kwargs.get("context", "")
|
||||
|
||||
current_depth = override_kwargs.current_depth if override_kwargs else 0
|
||||
parent_ids = override_kwargs.parent_agent_ids if override_kwargs else []
|
||||
|
||||
if current_depth >= MAX_AGENT_RECURSION_DEPTH:
|
||||
error_msg = (
|
||||
f"Maximum agent recursion depth ({MAX_AGENT_RECURSION_DEPTH}) reached."
|
||||
)
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolFinal(
|
||||
agent_name=self._child_persona.name,
|
||||
summary=error_msg,
|
||||
full_response=error_msg,
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolResponse(rich_response=None, llm_facing_response=error_msg)
|
||||
|
||||
if self._child_persona.id in parent_ids:
|
||||
error_msg = f"Circular agent invocation detected. Agent '{self._child_persona.name}' is already in the call chain."
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolFinal(
|
||||
agent_name=self._child_persona.name,
|
||||
summary=error_msg,
|
||||
full_response=error_msg,
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolResponse(rich_response=None, llm_facing_response=error_msg)
|
||||
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolDelta(
|
||||
agent_name=self._child_persona.name,
|
||||
status_text=f"Processing task: {task[:100]}...",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
response_text = self._run_child_agent_loop(
|
||||
task=task,
|
||||
context=context,
|
||||
turn_index=turn_index,
|
||||
)
|
||||
|
||||
if not response_text:
|
||||
response_text = "The agent did not produce a response."
|
||||
|
||||
if self._config.max_tokens_from_child:
|
||||
max_chars = self._config.max_tokens_from_child * 4
|
||||
if len(response_text) > max_chars:
|
||||
response_text = response_text[:max_chars] + "... [truncated]"
|
||||
|
||||
summary = (
|
||||
response_text[:200] + "..."
|
||||
if len(response_text) > 200
|
||||
else response_text
|
||||
)
|
||||
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolFinal(
|
||||
agent_name=self._child_persona.name,
|
||||
summary=summary,
|
||||
full_response=response_text,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=f"Response from {self._child_persona.name}:\n{response_text}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error invoking agent {self._child_persona.name}")
|
||||
error_msg = f"Error invoking agent: {str(e)}"
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentToolFinal(
|
||||
agent_name=self._child_persona.name,
|
||||
summary=error_msg,
|
||||
full_response=error_msg,
|
||||
),
|
||||
)
|
||||
)
|
||||
return ToolResponse(rich_response=None, llm_facing_response=error_msg)
|
||||
15
backend/onyx/tools/tool_implementations/agent/models.py
Normal file
15
backend/onyx/tools/tool_implementations/agent/models.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AgentToolOverrideKwargs(BaseModel):
|
||||
current_depth: int = 0
|
||||
parent_agent_ids: list[int] = []
|
||||
original_query: str = ""
|
||||
|
||||
|
||||
class AgentInvocationConfig(BaseModel):
|
||||
pass_conversation_context: bool = True
|
||||
pass_files: bool = False
|
||||
max_tokens_to_child: int | None = None
|
||||
max_tokens_from_child: int | None = None
|
||||
invocation_instructions: str | None = None
|
||||
6
backend/package-lock.json
generated
Normal file
6
backend/package-lock.json
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"name": "backend",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
||||
112
web/package-lock.json
generated
112
web/package-lock.json
generated
@@ -1733,6 +1733,118 @@
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.0.7.tgz",
|
||||
"integrity": "sha512-rtZ7BhnVvO1ICf3QzfW9H3aPz7GhBrnSIMZyr4Qy6boXF0b5E3QLs+cvJmg3PsTCG2M1PBoC+DANUi4wCOKXpA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.0.7.tgz",
|
||||
"integrity": "sha512-mloD5WcPIeIeeZqAIP5c2kdaTa6StwP4/2EGy1mUw8HiexSHGK/jcM7lFuS3u3i2zn+xH9+wXJs6njO7VrAqww==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.0.7.tgz",
|
||||
"integrity": "sha512-+ksWNrZrthisXuo9gd1XnjHRowCbMtl/YgMpbRvFeDEqEBd523YHPWpBuDjomod88U8Xliw5DHhekBC3EOOd9g==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.0.7.tgz",
|
||||
"integrity": "sha512-4WtJU5cRDxpEE44Ana2Xro1284hnyVpBb62lIpU5k85D8xXxatT+rXxBgPkc7C1XwkZMWpK5rXLXTh9PFipWsA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.0.7.tgz",
|
||||
"integrity": "sha512-HYlhqIP6kBPXalW2dbMTSuB4+8fe+j9juyxwfMwCe9kQPPeiyFn7NMjNfoFOfJ2eXkeQsoUGXg+O2SE3m4Qg2w==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.0.7.tgz",
|
||||
"integrity": "sha512-EviG+43iOoBRZg9deGauXExjRphhuYmIOJ12b9sAPy0eQ6iwcPxfED2asb/s2/yiLYOdm37kPaiZu8uXSYPs0Q==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.0.7.tgz",
|
||||
"integrity": "sha512-gniPjy55zp5Eg0896qSrf3yB1dw4F/3s8VK1ephdsZZ129j2n6e1WqCbE2YgcKhW9hPB9TVZENugquWJD5x0ug==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
},
|
||||
"node_modules/@nodelib/fs.scandir": {
|
||||
"version": "2.1.5",
|
||||
"license": "MIT",
|
||||
|
||||
@@ -148,9 +148,7 @@ export default function AssistantEditor({
|
||||
tools,
|
||||
shouldAddAssistantToUserPreferences,
|
||||
}: AssistantEditorProps) {
|
||||
// NOTE: assistants = agents
|
||||
// TODO: rename everything to agents
|
||||
const { refresh: refreshAgents } = useAgents();
|
||||
const { agents: availableAgents, refresh: refreshAgents } = useAgents();
|
||||
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
@@ -274,6 +272,16 @@ export default function AssistantEditor({
|
||||
? "user_files"
|
||||
: "team_knowledge",
|
||||
is_default_persona: existingPersona?.is_default_persona ?? false,
|
||||
child_persona_ids:
|
||||
existingPersona?.child_persona_configs?.map((c) => c.persona_id) ?? [],
|
||||
child_persona_configs:
|
||||
existingPersona?.child_persona_configs?.reduce(
|
||||
(acc, config) => {
|
||||
acc[config.persona_id] = config;
|
||||
return acc;
|
||||
},
|
||||
{} as Record<number, any>
|
||||
) ?? {},
|
||||
};
|
||||
|
||||
interface AssistantPrompt {
|
||||
@@ -478,6 +486,8 @@ export default function AssistantEditor({
|
||||
selectedGroups: Yup.array().of(Yup.number()),
|
||||
knowledge_source: Yup.string().required(),
|
||||
is_default_persona: Yup.boolean().required(),
|
||||
child_persona_ids: Yup.array().of(Yup.number()),
|
||||
child_persona_configs: Yup.object(),
|
||||
})
|
||||
.test(
|
||||
"system-prompt-or-task-prompt",
|
||||
@@ -553,6 +563,18 @@ export default function AssistantEditor({
|
||||
const groups = values.is_public ? [] : values.selectedGroups;
|
||||
const teamKnowledge = values.knowledge_source === "team_knowledge";
|
||||
|
||||
const configMap =
|
||||
(values.child_persona_configs as Record<number, any>) || {};
|
||||
const childPersonaConfigs = values.child_persona_ids?.map(
|
||||
(id: number) => ({
|
||||
persona_id: id,
|
||||
...(configMap[id] || {
|
||||
pass_conversation_context: true,
|
||||
pass_files: false,
|
||||
}),
|
||||
})
|
||||
);
|
||||
|
||||
const submissionData: PersonaUpsertParameters = {
|
||||
...values,
|
||||
starter_messages: starterMessages,
|
||||
@@ -571,6 +593,8 @@ export default function AssistantEditor({
|
||||
num_chunks: numChunks,
|
||||
document_set_ids: teamKnowledge ? values.document_set_ids : [],
|
||||
user_file_ids: teamKnowledge ? [] : values.user_file_ids,
|
||||
child_persona_ids: values.child_persona_ids || [],
|
||||
child_persona_configs: childPersonaConfigs,
|
||||
};
|
||||
|
||||
let personaResponse;
|
||||
@@ -1190,6 +1214,8 @@ export default function AssistantEditor({
|
||||
: "Image Generation requires an OpenAI or Azure Dall-E configuration."
|
||||
}
|
||||
hideSearchTool={true}
|
||||
availableAgents={availableAgents}
|
||||
currentPersonaId={existingPersona?.id}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -9,6 +9,23 @@ export interface StarterMessage extends StarterMessageBase {
|
||||
name: string;
|
||||
}
|
||||
|
||||
export interface ChildPersonaSnapshot {
|
||||
id: number;
|
||||
name: string;
|
||||
description: string;
|
||||
uploaded_image_id?: string;
|
||||
icon_name?: string;
|
||||
}
|
||||
|
||||
export interface ChildPersonaConfig {
|
||||
persona_id: number;
|
||||
pass_conversation_context: boolean;
|
||||
pass_files: boolean;
|
||||
max_tokens_to_child?: number | null;
|
||||
max_tokens_from_child?: number | null;
|
||||
invocation_instructions?: string | null;
|
||||
}
|
||||
|
||||
export interface MinimalPersonaSnapshot {
|
||||
id: number;
|
||||
name: string;
|
||||
@@ -37,8 +54,9 @@ export interface Persona extends MinimalPersonaSnapshot {
|
||||
users: MinimalUserSnapshot[];
|
||||
groups: number[];
|
||||
num_chunks?: number;
|
||||
child_personas?: ChildPersonaSnapshot[];
|
||||
child_persona_configs?: ChildPersonaConfig[];
|
||||
|
||||
// Embedded prompt fields on persona
|
||||
system_prompt: string | null;
|
||||
task_prompt: string | null;
|
||||
datetime_aware: boolean;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import {
|
||||
ChildPersonaConfig,
|
||||
MinimalPersonaSnapshot,
|
||||
Persona,
|
||||
StarterMessage,
|
||||
@@ -30,6 +31,8 @@ interface PersonaUpsertRequest {
|
||||
display_priority: number | null;
|
||||
label_ids: number[] | null;
|
||||
user_file_ids: string[] | null;
|
||||
child_persona_ids: number[];
|
||||
child_persona_configs?: ChildPersonaConfig[] | null;
|
||||
}
|
||||
|
||||
export interface PersonaUpsertParameters {
|
||||
@@ -54,6 +57,8 @@ export interface PersonaUpsertParameters {
|
||||
is_default_persona: boolean;
|
||||
label_ids: number[] | null;
|
||||
user_file_ids: string[];
|
||||
child_persona_ids: number[];
|
||||
child_persona_configs?: ChildPersonaConfig[] | null;
|
||||
}
|
||||
|
||||
function buildPersonaUpsertRequest(
|
||||
@@ -76,6 +81,8 @@ function buildPersonaUpsertRequest(
|
||||
remove_image,
|
||||
search_start_date,
|
||||
user_file_ids,
|
||||
child_persona_ids,
|
||||
child_persona_configs,
|
||||
} = creationRequest;
|
||||
|
||||
return {
|
||||
@@ -106,6 +113,8 @@ function buildPersonaUpsertRequest(
|
||||
display_priority: null,
|
||||
label_ids: creationRequest.label_ids ?? null,
|
||||
user_file_ids: user_file_ids ?? null,
|
||||
child_persona_ids: child_persona_ids ?? [],
|
||||
child_persona_configs: child_persona_configs ?? null,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -218,6 +218,7 @@ export default function AIMessage({
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
PacketType.REASONING_START,
|
||||
PacketType.AGENT_TOOL_START,
|
||||
];
|
||||
return packets.some((packet) =>
|
||||
contentPacketTypes.includes(packet.obj.type as PacketType)
|
||||
|
||||
@@ -32,8 +32,17 @@ type DisplayItem = {
|
||||
packets: Packet[];
|
||||
};
|
||||
|
||||
// Helper to check if a tool group is an agent tool (which may contain nested search)
|
||||
function isAgentToolGroup(packets: Packet[]): boolean {
|
||||
return packets.some((p) => p.obj.type === PacketType.AGENT_TOOL_START);
|
||||
}
|
||||
|
||||
// Helper to check if a tool group is an internal search (not internet search)
|
||||
// Excludes agent tool groups since they handle nested searches internally
|
||||
function isInternalSearchToolGroup(packets: Packet[]): boolean {
|
||||
// If this is an agent tool group, don't treat it as a search tool group
|
||||
if (isAgentToolGroup(packets)) return false;
|
||||
|
||||
const hasSearchStart = packets.some(
|
||||
(p) => p.obj.type === PacketType.SEARCH_TOOL_START
|
||||
);
|
||||
|
||||
@@ -18,6 +18,7 @@ import { PythonToolRenderer } from "./renderers/PythonToolRenderer";
|
||||
import { ReasoningRenderer } from "./renderers/ReasoningRenderer";
|
||||
import CustomToolRenderer from "./renderers/CustomToolRenderer";
|
||||
import { FetchToolRenderer } from "./renderers/FetchToolRenderer";
|
||||
import { AgentToolRenderer } from "./renderers/AgentToolRenderer";
|
||||
|
||||
// Different types of chat packets using discriminated unions
|
||||
export interface GroupedPackets {
|
||||
@@ -52,6 +53,10 @@ function isFetchToolPacket(packet: Packet) {
|
||||
return packet.obj.type === PacketType.FETCH_TOOL_START;
|
||||
}
|
||||
|
||||
function isAgentToolPacket(packet: Packet) {
|
||||
return packet.obj.type === PacketType.AGENT_TOOL_START;
|
||||
}
|
||||
|
||||
function isReasoningPacket(packet: Packet): packet is ReasoningPacket {
|
||||
return (
|
||||
packet.obj.type === PacketType.REASONING_START ||
|
||||
@@ -66,6 +71,10 @@ export function findRenderer(
|
||||
if (groupedPackets.packets.some((packet) => isChatPacket(packet))) {
|
||||
return MessageTextRenderer;
|
||||
}
|
||||
// Agent tool check comes first - it handles nested search/fetch tool packets internally
|
||||
if (groupedPackets.packets.some((packet) => isAgentToolPacket(packet))) {
|
||||
return AgentToolRenderer;
|
||||
}
|
||||
if (groupedPackets.packets.some((packet) => isSearchToolPacket(packet))) {
|
||||
return SearchToolRenderer;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import {
|
||||
FiUsers,
|
||||
FiChevronDown,
|
||||
FiChevronRight,
|
||||
FiSearch,
|
||||
FiBookOpen,
|
||||
FiCheck,
|
||||
} from "react-icons/fi";
|
||||
import { RenderType, MessageRenderer } from "../interfaces";
|
||||
import { Packet, PacketType } from "@/app/chat/services/streamingModels";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
import { SourceChip2 } from "@/app/chat/components/SourceChip2";
|
||||
import { truncateString } from "@/lib/utils";
|
||||
import { ResultIcon } from "@/components/chat/sources/SourceCard";
|
||||
import { BlinkingDot } from "../../BlinkingDot";
|
||||
|
||||
const MAX_TITLE_LENGTH = 25;
|
||||
const INITIAL_QUERIES_TO_SHOW = 3;
|
||||
const QUERIES_PER_EXPANSION = 5;
|
||||
const INITIAL_RESULTS_TO_SHOW = 3;
|
||||
const RESULTS_PER_EXPANSION = 10;
|
||||
|
||||
interface AgentRunState {
|
||||
agentName: string;
|
||||
agentId: number;
|
||||
statusText: string;
|
||||
nestedContent: string;
|
||||
summary: string;
|
||||
fullResponse: string;
|
||||
isRunning: boolean;
|
||||
isComplete: boolean;
|
||||
nestedSearchQueries: string[];
|
||||
nestedSearchDocs: OnyxDocument[];
|
||||
hasNestedSearch: boolean;
|
||||
depth: number;
|
||||
}
|
||||
|
||||
function buildAgentRunsFromPackets(packets: Packet[]): AgentRunState[] {
|
||||
const runs: AgentRunState[] = [];
|
||||
const stack: AgentRunState[] = [];
|
||||
|
||||
for (const packet of packets) {
|
||||
const obj = packet.obj;
|
||||
|
||||
if (obj.type === PacketType.AGENT_TOOL_START) {
|
||||
const run: AgentRunState = {
|
||||
agentName: obj.agent_name || "",
|
||||
agentId: obj.agent_id || 0,
|
||||
statusText: "",
|
||||
nestedContent: "",
|
||||
summary: "",
|
||||
fullResponse: "",
|
||||
isRunning: true,
|
||||
isComplete: false,
|
||||
nestedSearchQueries: [],
|
||||
nestedSearchDocs: [],
|
||||
hasNestedSearch: false,
|
||||
depth: stack.length,
|
||||
};
|
||||
runs.push(run);
|
||||
stack.push(run);
|
||||
} else if (obj.type === PacketType.AGENT_TOOL_DELTA) {
|
||||
const run = stack[stack.length - 1];
|
||||
if (!run) continue;
|
||||
if (obj.status_text) {
|
||||
run.statusText = obj.status_text;
|
||||
}
|
||||
if (obj.nested_content) {
|
||||
run.nestedContent += obj.nested_content;
|
||||
}
|
||||
} else if (obj.type === PacketType.AGENT_TOOL_FINAL) {
|
||||
const run = stack[stack.length - 1];
|
||||
if (!run) continue;
|
||||
run.summary = obj.summary || "";
|
||||
run.fullResponse = obj.full_response || "";
|
||||
run.isRunning = false;
|
||||
run.isComplete = true;
|
||||
stack.pop();
|
||||
} else if (
|
||||
obj.type === PacketType.SEARCH_TOOL_START ||
|
||||
obj.type === PacketType.SEARCH_TOOL_QUERIES_DELTA ||
|
||||
obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA
|
||||
) {
|
||||
const run = stack[stack.length - 1];
|
||||
if (!run) continue;
|
||||
if (obj.type === PacketType.SEARCH_TOOL_START) {
|
||||
run.hasNestedSearch = true;
|
||||
} else if (obj.type === PacketType.SEARCH_TOOL_QUERIES_DELTA) {
|
||||
run.nestedSearchQueries = obj.queries || [];
|
||||
} else if (obj.type === PacketType.SEARCH_TOOL_DOCUMENTS_DELTA) {
|
||||
run.nestedSearchDocs = obj.documents || [];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return runs;
|
||||
}
|
||||
|
||||
function AgentRunRenderer({
|
||||
run,
|
||||
renderType,
|
||||
depth,
|
||||
onComplete,
|
||||
children,
|
||||
}: {
|
||||
run: AgentRunState;
|
||||
renderType: RenderType;
|
||||
depth: number;
|
||||
onComplete: () => void;
|
||||
children: any;
|
||||
}) {
|
||||
const [isExpanded, setIsExpanded] = useState(false);
|
||||
const [queriesToShow, setQueriesToShow] = useState(INITIAL_QUERIES_TO_SHOW);
|
||||
const [resultsToShow, setResultsToShow] = useState(INITIAL_RESULTS_TO_SHOW);
|
||||
|
||||
useEffect(() => {
|
||||
if (run.isComplete) {
|
||||
onComplete();
|
||||
}
|
||||
}, [run.isComplete, onComplete]);
|
||||
|
||||
const status = useMemo(() => {
|
||||
if (run.isComplete) {
|
||||
return `Agent: ${run.agentName}`;
|
||||
}
|
||||
if (run.isRunning) {
|
||||
return `Agent: ${run.agentName}`;
|
||||
}
|
||||
return null;
|
||||
}, [run.agentName, run.isComplete, run.isRunning]);
|
||||
|
||||
const icon = FiUsers;
|
||||
|
||||
if (renderType === RenderType.HIGHLIGHT) {
|
||||
return children({
|
||||
icon,
|
||||
status: status,
|
||||
content: (
|
||||
<div
|
||||
className="flex flex-col gap-1 text-sm text-muted-foreground"
|
||||
style={{ marginLeft: depth ? depth * 12 : 0 }}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="font-medium text-foreground">
|
||||
Agent: {run.agentName}
|
||||
</span>
|
||||
{run.isRunning && (
|
||||
<span className="text-xs text-blue-500">working…</span>
|
||||
)}
|
||||
{run.isComplete && (
|
||||
<span className="text-xs text-green-600 dark:text-green-400 flex items-center gap-1">
|
||||
<FiCheck className="w-3 h-3" />
|
||||
completed
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{run.statusText && (
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{run.statusText}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
return children({
|
||||
icon,
|
||||
status,
|
||||
content: (
|
||||
<div
|
||||
className="flex flex-col mt-1.5"
|
||||
style={{ marginLeft: depth ? depth * 14 : 0 }}
|
||||
>
|
||||
{run.hasNestedSearch && (
|
||||
<div className="flex flex-col mt-3">
|
||||
<div className="flex items-center gap-2 mb-2 ml-1">
|
||||
<FiSearch className="w-3.5 h-3.5 text-gray-500" />
|
||||
<span className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Searching internally
|
||||
</span>
|
||||
{run.nestedSearchQueries.length > 0 && (
|
||||
<span className="text-xs text-gray-500">
|
||||
({run.nestedSearchQueries.length} queries)
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-x-2 gap-y-2 ml-1">
|
||||
{run.nestedSearchQueries
|
||||
.slice(0, queriesToShow)
|
||||
.map((query, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className="text-xs animate-in fade-in slide-in-from-left-2 duration-150"
|
||||
style={{
|
||||
animationDelay: `${index * 30}ms`,
|
||||
animationFillMode: "backwards",
|
||||
}}
|
||||
>
|
||||
<SourceChip2
|
||||
icon={<FiSearch size={10} />}
|
||||
title={truncateString(query, MAX_TITLE_LENGTH)}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
{run.nestedSearchQueries.length > queriesToShow && (
|
||||
<div
|
||||
className="text-xs animate-in fade-in slide-in-from-left-2 duration-150"
|
||||
style={{
|
||||
animationDelay: `${queriesToShow * 30}ms`,
|
||||
animationFillMode: "backwards",
|
||||
}}
|
||||
>
|
||||
<SourceChip2
|
||||
title={`${
|
||||
run.nestedSearchQueries.length - queriesToShow
|
||||
} more...`}
|
||||
onClick={() => {
|
||||
setQueriesToShow((prev) =>
|
||||
Math.min(
|
||||
prev + QUERIES_PER_EXPANSION,
|
||||
run.nestedSearchQueries.length
|
||||
)
|
||||
);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{run.nestedSearchQueries.length === 0 && run.isRunning && (
|
||||
<BlinkingDot />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{run.nestedSearchDocs.length > 0 && (
|
||||
<div className="flex flex-col mt-3">
|
||||
<div className="flex items-center gap-2 mb-2 ml-1">
|
||||
<FiBookOpen className="w-3.5 h-3.5 text-gray-500" />
|
||||
<span className="text-sm text-gray-600 dark:text-gray-400">
|
||||
Reading documents
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-x-2 gap-y-2 ml-1">
|
||||
{run.nestedSearchDocs
|
||||
.slice(0, resultsToShow)
|
||||
.map((doc, index) => (
|
||||
<div
|
||||
key={doc.document_id}
|
||||
className="text-xs animate-in fade-in slide-in-from-left-2 duration-150"
|
||||
style={{
|
||||
animationDelay: `${index * 30}ms`,
|
||||
animationFillMode: "backwards",
|
||||
}}
|
||||
>
|
||||
<SourceChip2
|
||||
icon={<ResultIcon doc={doc} size={10} />}
|
||||
title={truncateString(
|
||||
doc.semantic_identifier || "",
|
||||
MAX_TITLE_LENGTH
|
||||
)}
|
||||
onClick={() => {
|
||||
if (doc.link) {
|
||||
window.open(doc.link, "_blank");
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
{run.nestedSearchDocs.length > resultsToShow && (
|
||||
<div
|
||||
className="text-xs animate-in fade-in slide-in-from-left-2 duration-150"
|
||||
style={{
|
||||
animationDelay: `${
|
||||
Math.min(resultsToShow, run.nestedSearchDocs.length) * 30
|
||||
}ms`,
|
||||
animationFillMode: "backwards",
|
||||
}}
|
||||
>
|
||||
<SourceChip2
|
||||
title={`${
|
||||
run.nestedSearchDocs.length - resultsToShow
|
||||
} more...`}
|
||||
onClick={() => {
|
||||
setResultsToShow((prev) =>
|
||||
Math.min(
|
||||
prev + RESULTS_PER_EXPANSION,
|
||||
run.nestedSearchDocs.length
|
||||
)
|
||||
);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{run.nestedSearchDocs.length === 0 &&
|
||||
run.nestedSearchQueries.length > 0 &&
|
||||
run.isRunning && <BlinkingDot />}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{run.isRunning && run.nestedContent && (
|
||||
<div className="mt-3 text-sm text-gray-700 dark:text-gray-300 whitespace-pre-wrap border-l-2 border-blue-200 dark:border-blue-800 pl-3 ml-1">
|
||||
{run.nestedContent}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{run.isComplete && run.fullResponse && (
|
||||
<div className="mt-3 ml-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setIsExpanded(!isExpanded)}
|
||||
className="flex items-center gap-1 text-xs text-gray-500 hover:text-gray-700 dark:hover:text-gray-300"
|
||||
>
|
||||
{isExpanded ? (
|
||||
<FiChevronDown className="w-3 h-3" />
|
||||
) : (
|
||||
<FiChevronRight className="w-3 h-3" />
|
||||
)}
|
||||
{isExpanded ? "Hide agent response" : "View agent response"}
|
||||
</button>
|
||||
{isExpanded && (
|
||||
<div className="mt-2 text-sm bg-gray-50 dark:bg-gray-800 p-3 rounded border max-h-96 overflow-y-auto whitespace-pre-wrap">
|
||||
{run.fullResponse}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{run.isRunning && !run.nestedContent && !run.hasNestedSearch && (
|
||||
<div className="ml-1">
|
||||
<BlinkingDot />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
export const AgentToolRenderer: MessageRenderer<Packet, {}> = ({
|
||||
packets,
|
||||
onComplete,
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const runs = buildAgentRunsFromPackets(packets);
|
||||
|
||||
return (
|
||||
<>
|
||||
{runs.map((run, idx) => (
|
||||
<AgentRunRenderer
|
||||
key={`${run.agentName}-${idx}`}
|
||||
run={run}
|
||||
renderType={renderType}
|
||||
depth={run.depth}
|
||||
onComplete={onComplete}
|
||||
>
|
||||
{children}
|
||||
</AgentRunRenderer>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default AgentToolRenderer;
|
||||
@@ -24,6 +24,9 @@ export function isToolPacket(
|
||||
PacketType.FETCH_TOOL_START,
|
||||
PacketType.FETCH_TOOL_URLS,
|
||||
PacketType.FETCH_TOOL_DOCUMENTS,
|
||||
PacketType.AGENT_TOOL_START,
|
||||
PacketType.AGENT_TOOL_DELTA,
|
||||
PacketType.AGENT_TOOL_FINAL,
|
||||
];
|
||||
if (includeSectionEnd) {
|
||||
toolPacketTypes.push(PacketType.SECTION_END);
|
||||
|
||||
@@ -29,6 +29,11 @@ export enum PacketType {
|
||||
CUSTOM_TOOL_START = "custom_tool_start",
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta",
|
||||
|
||||
// Agent tool packets
|
||||
AGENT_TOOL_START = "agent_tool_start",
|
||||
AGENT_TOOL_DELTA = "agent_tool_delta",
|
||||
AGENT_TOOL_FINAL = "agent_tool_final",
|
||||
|
||||
// Reasoning packets
|
||||
REASONING_START = "reasoning_start",
|
||||
REASONING_DELTA = "reasoning_delta",
|
||||
@@ -143,6 +148,27 @@ export interface CustomToolDelta extends BaseObj {
|
||||
file_ids?: string[] | null;
|
||||
}
|
||||
|
||||
// Agent Tool Packets
|
||||
export interface AgentToolStart extends BaseObj {
|
||||
type: "agent_tool_start";
|
||||
agent_name: string;
|
||||
agent_id: number;
|
||||
}
|
||||
|
||||
export interface AgentToolDelta extends BaseObj {
|
||||
type: "agent_tool_delta";
|
||||
agent_name: string;
|
||||
status_text?: string;
|
||||
nested_content?: string;
|
||||
}
|
||||
|
||||
export interface AgentToolFinal extends BaseObj {
|
||||
type: "agent_tool_final";
|
||||
agent_name: string;
|
||||
summary: string;
|
||||
full_response?: string;
|
||||
}
|
||||
|
||||
// Reasoning Packets
|
||||
export interface ReasoningStart extends BaseObj {
|
||||
type: "reasoning_start";
|
||||
@@ -198,12 +224,18 @@ export type FetchToolObj =
|
||||
| FetchToolDocuments
|
||||
| SectionEnd;
|
||||
export type CustomToolObj = CustomToolStart | CustomToolDelta | SectionEnd;
|
||||
export type AgentToolObj =
|
||||
| AgentToolStart
|
||||
| AgentToolDelta
|
||||
| AgentToolFinal
|
||||
| SectionEnd;
|
||||
export type NewToolObj =
|
||||
| SearchToolObj
|
||||
| ImageGenerationToolObj
|
||||
| PythonToolObj
|
||||
| FetchToolObj
|
||||
| CustomToolObj;
|
||||
| CustomToolObj
|
||||
| AgentToolObj;
|
||||
|
||||
export type ReasoningObj = ReasoningStart | ReasoningDelta | SectionEnd;
|
||||
|
||||
@@ -269,6 +301,11 @@ export interface CustomToolPacket {
|
||||
obj: CustomToolObj;
|
||||
}
|
||||
|
||||
export interface AgentToolPacket {
|
||||
turn_index: number;
|
||||
obj: AgentToolObj;
|
||||
}
|
||||
|
||||
export interface ReasoningPacket {
|
||||
turn_index: number;
|
||||
obj: ReasoningObj;
|
||||
|
||||
220
web/src/components/admin/assistants/AgentSelector.tsx
Normal file
220
web/src/components/admin/assistants/AgentSelector.tsx
Normal file
@@ -0,0 +1,220 @@
|
||||
"use client";
|
||||
|
||||
import React, { memo, useState, useCallback, useMemo } from "react";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Info } from "lucide-react";
|
||||
import { FiChevronDown, FiChevronRight, FiUsers, FiX } from "react-icons/fi";
|
||||
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
|
||||
interface AgentSelectorProps {
|
||||
availableAgents: MinimalPersonaSnapshot[];
|
||||
currentPersonaId?: number;
|
||||
}
|
||||
|
||||
export const AgentSelector = memo(function AgentSelector({
|
||||
availableAgents,
|
||||
currentPersonaId,
|
||||
}: AgentSelectorProps) {
|
||||
const { values, setFieldValue } = useFormikContext<any>();
|
||||
const [expandedAgentId, setExpandedAgentId] = useState<number | null>(null);
|
||||
|
||||
const filteredAgents = useMemo(
|
||||
() =>
|
||||
availableAgents.filter(
|
||||
(agent) =>
|
||||
agent.id !== currentPersonaId &&
|
||||
!agent.builtin_persona &&
|
||||
agent.is_visible
|
||||
),
|
||||
[availableAgents, currentPersonaId]
|
||||
);
|
||||
|
||||
const selectedAgentIds: number[] = values.child_persona_ids || [];
|
||||
|
||||
const selectedAgents = useMemo(
|
||||
() => filteredAgents.filter((agent) => selectedAgentIds.includes(agent.id)),
|
||||
[filteredAgents, selectedAgentIds]
|
||||
);
|
||||
|
||||
const handleSelect = useCallback(
|
||||
(option: { name: string; value: string | number }) => {
|
||||
const agentId =
|
||||
typeof option.value === "string"
|
||||
? parseInt(option.value, 10)
|
||||
: option.value;
|
||||
if (!selectedAgentIds.includes(agentId)) {
|
||||
setFieldValue("child_persona_ids", [...selectedAgentIds, agentId]);
|
||||
}
|
||||
},
|
||||
[selectedAgentIds, setFieldValue]
|
||||
);
|
||||
|
||||
const handleRemove = useCallback(
|
||||
(agentId: number) => {
|
||||
setFieldValue(
|
||||
"child_persona_ids",
|
||||
selectedAgentIds.filter((id) => id !== agentId)
|
||||
);
|
||||
const configMap = { ...(values.child_persona_configs || {}) };
|
||||
delete configMap[agentId];
|
||||
setFieldValue("child_persona_configs", configMap);
|
||||
},
|
||||
[selectedAgentIds, values.child_persona_configs, setFieldValue]
|
||||
);
|
||||
|
||||
const toggleExpand = useCallback((agentId: number) => {
|
||||
setExpandedAgentId((prev) => (prev === agentId ? null : agentId));
|
||||
}, []);
|
||||
|
||||
if (filteredAgents.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const dropdownOptions = filteredAgents
|
||||
.filter((agent) => !selectedAgentIds.includes(agent.id))
|
||||
.map((agent) => ({
|
||||
name: agent.name,
|
||||
value: agent.id,
|
||||
description: agent.description,
|
||||
}));
|
||||
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text mainUiBody text04>
|
||||
Agent Actions
|
||||
</Text>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<Info className="h-3.5 w-3.5 text-text-400 cursor-help" />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-xs space-y-2 max-w-xs text-white">
|
||||
<div>
|
||||
<span className="font-semibold">
|
||||
Allow this agent to invoke other agents as tools
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
direction="bottom"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<SearchMultiSelectDropdown
|
||||
options={dropdownOptions}
|
||||
onSelect={handleSelect}
|
||||
/>
|
||||
|
||||
{selectedAgents.length > 0 && (
|
||||
<div className="flex flex-col gap-2 mt-2">
|
||||
{selectedAgents.map((agent) => {
|
||||
const isExpanded = expandedAgentId === agent.id;
|
||||
const config = values.child_persona_configs?.[agent.id] || {
|
||||
pass_conversation_context: true,
|
||||
pass_files: false,
|
||||
invocation_instructions: "",
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={agent.id}
|
||||
className="border rounded-lg p-3 dark:border-gray-700 bg-background-subtle"
|
||||
>
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => toggleExpand(agent.id)}
|
||||
className="p-1 hover:bg-gray-100 dark:hover:bg-gray-700 rounded transition-colors"
|
||||
>
|
||||
{isExpanded ? (
|
||||
<FiChevronDown className="w-4 h-4 text-gray-500" />
|
||||
) : (
|
||||
<FiChevronRight className="w-4 h-4 text-gray-500" />
|
||||
)}
|
||||
</button>
|
||||
<div className="w-6 h-6 rounded-full bg-blue-100 dark:bg-blue-900 flex items-center justify-center">
|
||||
<FiUsers className="w-3 h-3 text-blue-600 dark:text-blue-400" />
|
||||
</div>
|
||||
<div>
|
||||
<div className="font-medium text-sm">{agent.name}</div>
|
||||
<div className="text-xs text-gray-500 line-clamp-1 max-w-[200px]">
|
||||
{agent.description}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleRemove(agent.id)}
|
||||
className="p-1 hover:bg-gray-100 dark:hover:bg-gray-700 rounded transition-colors text-gray-500 hover:text-gray-700"
|
||||
>
|
||||
<FiX className="w-4 h-4" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isExpanded && (
|
||||
<div className="mt-3 pt-3 border-t dark:border-gray-700 space-y-3 ml-8">
|
||||
<Text className="text-xs text-gray-500">
|
||||
Configuration for invoking this agent
|
||||
</Text>
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<Checkbox
|
||||
checked={config.pass_conversation_context}
|
||||
onCheckedChange={(checked) => {
|
||||
setFieldValue(`child_persona_configs.${agent.id}`, {
|
||||
...config,
|
||||
persona_id: agent.id,
|
||||
pass_conversation_context: checked,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<span className="text-sm">Pass conversation context</span>
|
||||
</label>
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<Checkbox
|
||||
checked={config.pass_files}
|
||||
onCheckedChange={(checked) => {
|
||||
setFieldValue(`child_persona_configs.${agent.id}`, {
|
||||
...config,
|
||||
persona_id: agent.id,
|
||||
pass_files: checked,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<span className="text-sm">Pass attached files</span>
|
||||
</label>
|
||||
<div>
|
||||
<label className="text-sm text-gray-600 dark:text-gray-400 mb-1 block">
|
||||
When to invoke (optional)
|
||||
</label>
|
||||
<textarea
|
||||
className="w-full text-sm border rounded p-2 dark:bg-gray-800 dark:border-gray-700"
|
||||
placeholder="e.g., Use this agent when the user asks about..."
|
||||
rows={2}
|
||||
value={config.invocation_instructions || ""}
|
||||
onChange={(e) => {
|
||||
setFieldValue(`child_persona_configs.${agent.id}`, {
|
||||
...config,
|
||||
persona_id: agent.id,
|
||||
invocation_instructions: e.target.value,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
export default AgentSelector;
|
||||
@@ -11,6 +11,7 @@ import { BooleanFormField } from "@/components/Field";
|
||||
import { ToolSnapshot, MCPServer } from "@/lib/tools/interfaces";
|
||||
import { MCPServerSection } from "./FormSections";
|
||||
import { MemoizedToolList } from "./MemoizedToolCheckboxes";
|
||||
import { AgentSelector } from "./AgentSelector";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SEARCH_TOOL_ID,
|
||||
@@ -21,6 +22,7 @@ import {
|
||||
} from "@/app/chat/components/tools/constants";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Info } from "lucide-react";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
|
||||
interface ToolSelectorProps {
|
||||
tools: ToolSnapshot[];
|
||||
@@ -32,6 +34,8 @@ interface ToolSelectorProps {
|
||||
searchToolDisabled?: boolean;
|
||||
searchToolDisabledTooltip?: string;
|
||||
hideSearchTool?: boolean;
|
||||
availableAgents?: MinimalPersonaSnapshot[];
|
||||
currentPersonaId?: number;
|
||||
}
|
||||
|
||||
export function ToolSelector({
|
||||
@@ -44,6 +48,8 @@ export function ToolSelector({
|
||||
searchToolDisabled = false,
|
||||
searchToolDisabledTooltip,
|
||||
hideSearchTool = false,
|
||||
availableAgents = [],
|
||||
currentPersonaId,
|
||||
}: ToolSelectorProps) {
|
||||
const searchTool = tools.find((t) => t.in_code_tool_id === SEARCH_TOOL_ID);
|
||||
const webSearchTool = tools.find(
|
||||
@@ -157,7 +163,7 @@ export function ToolSelector({
|
||||
<Info className="h-3.5 w-3.5 text-text-400 cursor-help" />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-xs space-y-2 max-w-xs">
|
||||
<div className="text-xs space-y-2 max-w-xs text-white">
|
||||
<div>
|
||||
<span className="font-semibold">Internal Search:</span> Requires
|
||||
at least one connector to be configured to search your
|
||||
@@ -301,6 +307,13 @@ export function ToolSelector({
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
|
||||
{availableAgents.length > 0 && (
|
||||
<AgentSelector
|
||||
availableAgents={availableAgents}
|
||||
currentPersonaId={currentPersonaId}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user