Compare commits

...

1 Commits

Author SHA1 Message Date
pablodanswer
28d3771b2d k 2025-01-06 10:15:05 -08:00
4 changed files with 20 additions and 13 deletions

View File

@@ -654,6 +654,7 @@ def stream_chat_message_objects(
user=user,
llm=llm,
fast_llm=fast_llm,
tenant_id=tenant_id,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,

View File

@@ -263,7 +263,7 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Loading default Prompts and Personas")
load_chat_yamls(db_session)
refresh_built_in_tools_cache(db_session)
refresh_built_in_tools_cache(db_session, tenant_id=None)
auto_add_search_tool_to_personas(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:

View File

@@ -151,12 +151,12 @@ def auto_add_search_tool_to_personas(db_session: Session) -> None:
logger.notice("Completed adding SearchTool to relevant Personas.")
_built_in_tools_cache: dict[int, Type[Tool]] | None = None
_built_in_tools_cache: dict[str | None, dict[int, Type[Tool]]] = {}
def refresh_built_in_tools_cache(db_session: Session) -> None:
def refresh_built_in_tools_cache(db_session: Session, tenant_id: str | None) -> None:
global _built_in_tools_cache
_built_in_tools_cache = {}
_built_in_tools_cache[tenant_id] = {}
all_tool_built_in_tools = (
db_session.execute(
select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None)))
@@ -174,22 +174,27 @@ def refresh_built_in_tools_cache(db_session: Session) -> None:
None,
)
if tool_info:
_built_in_tools_cache[tool.id] = tool_info["cls"]
_built_in_tools_cache[tenant_id][tool.id] = tool_info["cls"]
def get_built_in_tool_by_id(
tool_id: int, db_session: Session, force_refresh: bool = False
tool_id: int,
db_session: Session,
tenant_id: str | None,
force_refresh: bool = False,
) -> Type[Tool]:
global _built_in_tools_cache
if _built_in_tools_cache is None or force_refresh:
refresh_built_in_tools_cache(db_session)
if tenant_id not in _built_in_tools_cache or force_refresh:
refresh_built_in_tools_cache(db_session, tenant_id)
if _built_in_tools_cache is None:
if tenant_id not in _built_in_tools_cache:
raise RuntimeError(
"Built-in tools cache is None despite being refreshed. Should never happen."
f"Built-in tools cache for tenant {tenant_id} is None despite being refreshed. Should never happen."
)
if tool_id in _built_in_tools_cache:
return _built_in_tools_cache[tool_id]
if tool_id in _built_in_tools_cache[tenant_id]:
return _built_in_tools_cache[tenant_id][tool_id]
else:
raise ValueError(f"No built-in tool found in the cache with ID {tool_id}")
raise ValueError(
f"No built-in tool found in the cache with ID {tool_id} for tenant {tenant_id}"
)

View File

@@ -138,6 +138,7 @@ def construct_tools(
user: User | None,
llm: LLM,
fast_llm: LLM,
tenant_id: str | None,
search_tool_config: SearchToolConfig | None = None,
internet_search_tool_config: InternetSearchToolConfig | None = None,
image_generation_tool_config: ImageGenerationToolConfig | None = None,