Compare commits

...

6 Commits

Author SHA1 Message Date
Evan Lohn
f8605557e8 weird unicdoe 2026-02-09 09:58:54 -08:00
Evan Lohn
852eea4542 migration issues 2026-02-09 09:57:33 -08:00
Evan Lohn
76bf50ec5d more comments 2026-02-08 22:46:46 -08:00
Evan Lohn
be8009db1e typing 2026-02-08 22:43:07 -08:00
Evan Lohn
d78f2963c2 comments 2026-02-08 22:30:16 -08:00
Evan Lohn
54fae288c0 feat: file reader tool 2026-02-08 20:53:54 -08:00
7 changed files with 368 additions and 3 deletions

View File

@@ -0,0 +1,87 @@
"""add_file_reader_tool
Revision ID: d3fd499c829c
Revises: d56ffa94ca32
Create Date: 2026-02-07 19:28:22.452337
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d3fd499c829c"
down_revision = "d56ffa94ca32"
branch_labels = None
depends_on = None
FILE_READER_TOOL = {
"name": "read_file",
"display_name": "File Reader",
"description": (
"Read sections of user-uploaded files by character offset. "
"Useful for inspecting large files that cannot fit entirely in context."
),
"in_code_tool_id": "FileReaderTool",
"enabled": True,
}
def upgrade() -> None:
conn = op.get_bind()
# Check if tool already exists
existing = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
).fetchone()
if existing:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
FILE_READER_TOOL,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
"""
),
FILE_READER_TOOL,
)
def downgrade() -> None:
conn = op.get_bind()
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
# Remove persona associations first (FK constraint)
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE tool_id IN (
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
)
"""
),
{"in_code_tool_id": in_code_tool_id},
)
conn.execute(
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
{"in_code_tool_id": in_code_tool_id},
)

View File

@@ -86,6 +86,7 @@ from onyx.tools.interface import Tool
from onyx.tools.models import SearchToolUsage
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import CustomToolConfig
from onyx.tools.tool_constructor import FileReaderToolConfig
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -100,6 +101,39 @@ logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _collect_available_file_ids(
chat_history: list[ChatMessage],
project_id: int | None,
user_id: UUID | None,
db_session: Session,
) -> list[UUID]:
"""Collect all file IDs the FileReaderTool should be allowed to access.
Includes files attached to chat messages and files belonging to
the project (if any)."""
file_ids: set[UUID] = set()
for msg in chat_history:
if not msg.files:
continue
for fd in msg.files:
try:
file_ids.add(UUID(fd["id"]))
except (ValueError, KeyError):
pass
if project_id:
project_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
for uf in project_files:
file_ids.add(uf.id)
return list(file_ids)
def _should_enable_slack_search(
persona: Persona,
filters: BaseFilters | None,
@@ -514,6 +548,14 @@ def handle_stream_message_objects(
emitter = get_default_emitter()
# Collect file IDs for the file reader tool to restrict access
available_file_ids = _collect_available_file_ids(
chat_history=chat_history,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
@@ -540,6 +582,9 @@ def handle_stream_message_objects(
additional_headers=custom_tool_additional_headers,
mcp_headers=mcp_headers,
),
file_reader_tool_config=FileReaderToolConfig(
available_file_ids=available_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=project_search_config.search_usage,
)

View File

@@ -1,6 +1,7 @@
from typing import Type
from typing import Union
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
@@ -25,6 +26,7 @@ BUILT_IN_TOOL_TYPES = Union[
KnowledgeGraphTool,
OpenURLTool,
PythonTool,
FileReaderTool,
]
BUILT_IN_TOOL_MAP: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {
@@ -34,6 +36,7 @@ BUILT_IN_TOOL_MAP: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {
KnowledgeGraphTool.__name__: KnowledgeGraphTool,
OpenURLTool.__name__: OpenURLTool,
PythonTool.__name__: PythonTool,
FileReaderTool.__name__: FileReaderTool,
}
STOPPING_TOOLS_NAMES: list[str] = [ImageGenerationTool.NAME]

View File

@@ -13,3 +13,7 @@ IMAGE_GENERATION_TOOL_ID = "ImageGenerationTool"
WEB_SEARCH_TOOL_ID = "WebSearchTool"
PYTHON_TOOL_ID = "PythonTool"
OPEN_URL_TOOL_ID = "OpenURLTool"
FILE_READER_TOOL_ID = "FileReaderTool"
# Tool names as referenced by tool results / tool calls (read_file)
FILE_READER_TOOL_NAME = "read_file"

View File

@@ -30,6 +30,7 @@ from onyx.tools.models import SearchToolUsage
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
@@ -57,6 +58,10 @@ class SearchToolConfig(BaseModel):
enable_slack_search: bool = True
class FileReaderToolConfig(BaseModel):
available_file_ids: list[UUID] | None = None
class CustomToolConfig(BaseModel):
chat_session_id: UUID | None = None
message_id: int | None = None
@@ -99,6 +104,7 @@ def construct_tools(
llm: LLM,
search_tool_config: SearchToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
file_reader_tool_config: FileReaderToolConfig | None = None,
allowed_tool_ids: list[int] | None = None,
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
) -> dict[int, list[Tool]]:
@@ -234,6 +240,21 @@ def construct_tools(
PythonTool(tool_id=db_tool_model.id, emitter=emitter)
]
# Handle File Reader Tool
elif tool_cls.__name__ == FileReaderTool.__name__:
tool_dict[db_tool_model.id] = [
FileReaderTool(
tool_id=db_tool_model.id,
emitter=emitter,
available_file_ids=(
file_reader_tool_config.available_file_ids
if file_reader_tool_config
and file_reader_tool_config.available_file_ids
else []
),
)
]
# Handle KG Tool
# TODO: disabling for now because it's broken in the refactor
# elif tool_cls.__name__ == KnowledgeGraphTool.__name__:

View File

@@ -0,0 +1,196 @@
from typing import Any
from typing import cast
from uuid import UUID
from sqlalchemy.orm import Session
from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_store.models import ChatFileType
from onyx.file_store.utils import load_user_file
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolResponse
from onyx.utils.logger import setup_logger
logger = setup_logger()
FILE_ID_FIELD = "file_id"
START_CHAR_FIELD = "start_char"
NUM_CHARS_FIELD = "num_chars"
MAX_NUM_CHARS = 4000
DEFAULT_NUM_CHARS = MAX_NUM_CHARS
class FileReaderToolOverrideKwargs:
"""No override kwargs needed for the file reader tool."""
class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
NAME = "read_file"
DISPLAY_NAME = "File Reader"
DESCRIPTION = (
"Read a section of a user-uploaded file by character offset. "
"Returns up to 4000 characters starting from the given offset."
)
def __init__(
self,
tool_id: int,
emitter: Emitter,
available_file_ids: list[UUID],
) -> None:
super().__init__(emitter=emitter)
self._id = tool_id
self._available_file_ids = available_file_ids
@property
def id(self) -> int:
return self._id
@property
def name(self) -> str:
return self.NAME
@property
def description(self) -> str:
return self.DESCRIPTION
@property
def display_name(self) -> str:
return self.DISPLAY_NAME
@override
@classmethod
def is_available(cls, db_session: Session) -> bool: # noqa: ARG003
return True
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.DESCRIPTION,
"parameters": {
"type": "object",
"properties": {
FILE_ID_FIELD: {
"type": "string",
"description": "The UUID of the file to read.",
},
START_CHAR_FIELD: {
"type": "integer",
"description": (
"Character offset to start reading from. Defaults to 0."
),
},
NUM_CHARS_FIELD: {
"type": "integer",
"description": (
"Number of characters to return (max 4000). "
"Defaults to 4000."
),
},
},
"required": [FILE_ID_FIELD],
},
},
}
def emit_start(self, placement: Placement) -> None:
self.emitter.emit(
Packet(
placement=placement,
obj=CustomToolStart(tool_name=self.DISPLAY_NAME),
)
)
def _validate_file_id(self, raw_file_id: str) -> UUID:
try:
file_id = UUID(raw_file_id)
except ValueError:
raise ToolCallException(
message=f"Invalid file_id: {raw_file_id}",
llm_facing_message=f"'{raw_file_id}' is not a valid file UUID.",
)
if file_id not in self._available_file_ids:
raise ToolCallException(
message=f"File {file_id} not in available files",
llm_facing_message=(
f"File '{file_id}' is not available. "
"Please use one of the file IDs listed in the context."
),
)
return file_id
def run(
self,
placement: Placement, # noqa: ARG002
override_kwargs: FileReaderToolOverrideKwargs, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if FILE_ID_FIELD not in llm_kwargs:
raise ToolCallException(
message=f"Missing required '{FILE_ID_FIELD}' parameter",
llm_facing_message=(
f"The read_file tool requires a '{FILE_ID_FIELD}' parameter. "
f'Example: {{"file_id": "abc-123", "start_char": 0, "num_chars": 4000}}'
),
)
raw_file_id = cast(str, llm_kwargs[FILE_ID_FIELD])
file_id = self._validate_file_id(raw_file_id)
start_char = max(0, int(llm_kwargs.get(START_CHAR_FIELD, 0)))
num_chars = min(
MAX_NUM_CHARS,
max(1, int(llm_kwargs.get(NUM_CHARS_FIELD, DEFAULT_NUM_CHARS))),
)
with get_session_with_current_tenant() as db_session:
chat_file = load_user_file(file_id, db_session)
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
# DOC type in a loaded file means plaintext extraction failed and the
# content is the original binary (e.g. raw PDF/DOCX bytes).
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
raise ToolCallException(
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
llm_facing_message=(
f"File '{chat_file.filename or file_id}' is a "
f"{chat_file.file_type.value} file and cannot be read as text."
),
)
try:
full_text = chat_file.content.decode("utf-8", errors="replace")
except Exception:
raise ToolCallException(
message=f"Failed to decode file {file_id}",
llm_facing_message="The file could not be read as text.",
)
total_chars = len(full_text)
end_char = min(start_char + num_chars, total_chars)
section = full_text[start_char:end_char]
has_more = end_char < total_chars
header = (
f"File: {chat_file.filename or file_id}\n"
f"Characters {start_char}-{end_char} of {total_chars}"
)
if has_more:
header += f" (use start_char={end_char} to continue reading)"
llm_response = f"{header}\n\n{section}"
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)

View File

@@ -50,6 +50,12 @@ EXPECTED_TOOLS = {
in_code_tool_id="ResearchAgent",
user_id=None,
),
"FileReaderTool": ToolSeedingExpectedResult(
name="read_file",
display_name="File Reader",
in_code_tool_id="FileReaderTool",
user_id=None,
),
}
@@ -92,10 +98,10 @@ def test_tool_seeding_migration() -> None:
)
tools = result.fetchall()
# Should have all 8 builtin tools
# Should have all 9 builtin tools
assert (
len(tools) == 8
), f"Should have created exactly 8 builtin tools, got {len(tools)}"
len(tools) == 9
), f"Should have created exactly 9 builtin tools, got {len(tools)}"
def validate_tool(expected: ToolSeedingExpectedResult) -> None:
tool = next((t for t in tools if t[1] == expected.name), None)
@@ -127,3 +133,6 @@ def test_tool_seeding_migration() -> None:
# Check ResearchAgent (Deep Research as a tool)
validate_tool(EXPECTED_TOOLS["ResearchAgent"])
# Check FileReaderTool
validate_tool(EXPECTED_TOOLS["FileReaderTool"])