Compare commits

...

1 Commits

Author SHA1 Message Date
Harsh Deep
36484bca06 Add old conversation context in chat 2025-09-22 13:18:01 -04:00
3 changed files with 103 additions and 19 deletions

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
@@ -35,14 +36,17 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.configs.agent_configs import AGENT_USE_SHORT_TERM_MEMORY
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.db.chat import get_past_user_messages
from onyx.db.chat import update_db_session_with_messages
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import ChatMessage
from onyx.db.models import Tool
from onyx.db.tools import get_tools
from onyx.file_store.models import ChatFileType
@@ -331,6 +335,34 @@ ANY tool mentioned can be accessed through this generic tool. If in doubt, use t
}
def _build_chat_history_with_context(
    current_chat_history: list[BaseMessage],
    old_conversation_user_messages: list[ChatMessage] | None,
    max_chat_history_messages: int,
) -> str:
    """Render the chat history, optionally prefixed with older user messages.

    The current conversation is rendered via ``get_chat_history_string``; a
    placeholder is used when that yields nothing. When
    ``old_conversation_user_messages`` is non-empty, those messages are listed
    oldest-first (by ``time_sent``) under a "Previous conversations" heading,
    followed by the current conversation.

    Args:
        current_chat_history: Messages of the ongoing conversation.
        old_conversation_user_messages: User messages from other
            conversations, or ``None``/empty to skip the extra context.
        max_chat_history_messages: Passed through to
            ``get_chat_history_string``.

    Returns:
        The combined history string.
    """
    # TODO: Check context limits
    # TODO: Look into adding both the user & assistant messages in past history
    # TODO: Look into using the chat summary instead of user queries
    # TODO: Allow user-side config instead of using environment variable
    current_section = get_chat_history_string(
        current_chat_history, max_chat_history_messages
    )
    if not current_section:
        current_section = "(No chat history yet available)"

    # Nothing to prepend: return the current conversation on its own.
    if not old_conversation_user_messages:
        return current_section

    # Present the older messages chronologically, whatever order they arrived in.
    chronological = sorted(
        old_conversation_user_messages, key=lambda x: x.time_sent
    )
    past_section = "\n".join(f"user: {past.message}" for past in chronological)

    return (
        f"Previous conversations:\n{past_section}"
        f"\n\nCurrent conversation:\n{current_section}"
    )
def clarifier(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> OrchestrationSetup:
@@ -375,6 +407,15 @@ def clarifier(
db_session = graph_config.persistence.db_session
active_source_types = fetch_unique_document_sources(db_session)
old_conversation_user_messages = (
get_past_user_messages(
db_session,
graph_config.persistence.chat_session_id,
)
if AGENT_USE_SHORT_TERM_MEMORY
else None
)
available_tools = _get_available_tools(
db_session, graph_config, kg_enabled, active_source_types
)
@@ -421,12 +462,10 @@ def clarifier(
assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
assistant_task_prompt = ""
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
chat_history_string = _build_chat_history_with_context(
graph_config.inputs.prompt_builder.message_history,
old_conversation_user_messages,
MAX_CHAT_HISTORY_MESSAGES,
)
uploaded_text_context = (
@@ -625,12 +664,10 @@ def clarifier(
clarification, original_question, chat_history_string = result
else:
# generate clarification questions if needed
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
chat_history_string = _build_chat_history_with_context(
graph_config.inputs.prompt_builder.message_history,
old_conversation_user_messages,
MAX_CHAT_HISTORY_MESSAGES,
)
base_clarification_prompt = get_dr_prompt_orchestration_templates(
@@ -727,14 +764,11 @@ def clarifier(
db_session.commit()
else:
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
MAX_CHAT_HISTORY_MESSAGES,
)
or "(No chat history yet available)"
chat_history_string = _build_chat_history_with_context(
graph_config.inputs.prompt_builder.message_history,
old_conversation_user_messages,
MAX_CHAT_HISTORY_MESSAGES,
)
if (
clarification
and clarification.clarification_question

View File

@@ -387,3 +387,9 @@ TF_DR_TIMEOUT_SHORT = int(os.environ.get("TF_DR_TIMEOUT_SHORT") or 60)
# True only when the environment variable is set to "true" (case-insensitive).
TF_DR_DEFAULT_FAST = os.environ.get("TF_DR_DEFAULT_FAST", "").lower() == "true"

GRAPH_VERSION_NAME: str = "a"

# Defaults to False; enabled only when the environment variable is "true"
# (case-insensitive). Temporary switch: this is meant to become a per-user
# preference.
AGENT_USE_SHORT_TERM_MEMORY = (
    os.environ.get("AGENT_USE_SHORT_TERM_MEMORY", "").lower() == "true"
)

View File

@@ -10,6 +10,7 @@ from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import Row
@@ -193,6 +194,49 @@ def get_chat_sessions_by_user(
return list(chat_sessions)
def get_past_user_messages(
    db_session: Session,
    current_chat_session_id: UUID,
    limit: int = 10,
) -> list[ChatMessage]:
    """
    Fetch the most recent user messages from the user's other chat sessions.

    The current session is loaded to find its ``user_id``. Messages are then
    taken from every non-deleted session with that ``user_id``, excluding the
    current session itself.

    Args:
        db_session: Database session
        current_chat_session_id: ID of the current chat session
        limit: Maximum number of user messages to return (default: 10)

    Returns:
        Up to ``limit`` user messages, ordered by most recent first. An empty
        list is returned (and a warning logged) if the current session does
        not exist.
    """
    session_lookup = select(ChatSession).where(
        ChatSession.id == current_chat_session_id
    )
    current_session = db_session.execute(session_lookup).scalar_one_or_none()

    if not current_session:
        logger.warning(f"Current chat session {current_chat_session_id} not found")
        return []

    # TODO: Fix this query. If a chat session contains an edited message, both
    # the original and the edited message are currently returned; only the
    # newest path in the conversation tree should be.
    filters = (
        ChatMessage.message_type == MessageType.USER,
        ChatSession.user_id == current_session.user_id,
        ChatSession.id != current_chat_session_id,
        not_(ChatSession.deleted),
    )
    query = (
        select(ChatMessage)
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .where(*filters)
        .order_by(desc(ChatMessage.time_sent))
        .limit(limit)
    )

    return list(db_session.execute(query).scalars().all())
def delete_search_doc_message_relationship(
message_id: int, db_session: Session
) -> None: