Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: sharepoint...KG-prototy (41 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 050c0133c7 | |
| | a2bdbd23d8 | |
| | f87d44d24e | |
| | 53ef95ec69 | |
| | eb48354e8f | |
| | 7e12f02b62 | |
| | 221a4c19f0 | |
| | 2971eb7d59 | |
| | e7e786fd65 | |
| | baf4dd64b0 | |
| | 68853393ee | |
| | defcf8291a | |
| | 2ef9d19160 | |
| | 3acc069511 | |
| | 336b31fc1e | |
| | e22d414d33 | |
| | fcd749ab29 | |
| | 365b9b09e3 | |
| | 862807c13c | |
| | 15a095a068 | |
| | 4600788476 | |
| | 99b6b7dd11 | |
| | 49847b05f8 | |
| | be601d204a | |
| | ea12c25282 | |
| | 458d7fb124 | |
| | 4391d05ce3 | |
| | 5ec5e616f1 | |
| | 2cc87c7d53 | |
| | c017724e91 | |
| | e99eac4a1d | |
| | da4f348039 | |
| | 740d4a5a9d | |
| | c7c8330b90 | |
| | 6869f0403d | |
| | bacb1092ff | |
| | 1a119601e6 | |
| | 1980fc62c0 | |
| | 02a4232189 | |
| | f1cc6841f9 | |
| | e7d2f9a43a | |
@@ -0,0 +1,219 @@
"""create knowledge graph tables

Revision ID: 495cb26ce93e
Revises: 6a804aeb4830
Create Date: 2025-03-19 08:51:14.341989

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create KGEntityType table
    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column(
            "classification_requirements",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "extraction_sources", postgresql.JSONB, nullable=False, server_default="{}"
        ),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True, unique=True),
        sa.Column(
            "ge_determine_instructions", postgresql.ARRAY(sa.String()), nullable=True
        ),
        sa.Column("ge_grounding_signature", sa.String(), nullable=True),
    )

    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
    )
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])

    # Add KG columns to existing tables
    op.add_column(
        "document",
        sa.Column("kg_processed", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "document",
        sa.Column("kg_data", postgresql.JSONB(), nullable=False, server_default="{}"),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_extraction_enabled",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )

    op.add_column(
        "document_by_connector_credential_pair",
        sa.Column("has_been_kg_processed", sa.Boolean(), nullable=True),
    )


def downgrade() -> None:
    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_entity_type")
    op.drop_column("connector", "kg_extraction_enabled")
    op.drop_column("document_by_connector_credential_pair", "has_been_kg_processed")
    op.drop_column("document", "kg_data")
    op.drop_column("document", "kg_processed")
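For reference, a revision like this is typically applied and rolled back through Alembic's Python API (or the `alembic` CLI). The sketch below is illustrative only; the config path is an assumption and not part of this diff.

```python
# Sketch: apply / roll back revision 495cb26ce93e via Alembic's Python API.
# The alembic.ini location is an assumed, illustrative path.
from alembic import command
from alembic.config import Config

cfg = Config("backend/alembic.ini")

# Upgrade to this revision (creates the kg_* tables and new columns).
command.upgrade(cfg, "495cb26ce93e")

# Roll back to the previous revision (drops them again).
command.downgrade(cfg, "6a804aeb4830")
```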
@@ -0,0 +1,18 @@
from pydantic import BaseModel


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None
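These models are populated downstream by validating cleaned LLM JSON output (see the `extract_ert` node later in this diff). A small illustrative example, assuming Pydantic v2 and using the class defined just above; the entity and term values are hypothetical.

```python
# Illustrative only: mirrors the model_validate_json calls used later in this diff.
raw = '{"entities": ["ACCOUNT:acme"], "terms": ["churn"], "time_filter": null}'

result = KGQuestionEntityExtractionResult.model_validate_json(raw)
assert result.entities == ["ACCOUNT:acme"]
assert result.time_filter is None
```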
@@ -37,7 +37,7 @@ def research_object_source(
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.search_request.query
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.search_request.query
    graph_config.inputs.search_request.query
    object, document_source = state.object_source_combination

    if search_tool is None or graph_config.inputs.search_request.persona is None:
@@ -17,6 +17,9 @@ def research(
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
    kg_entities: list[str] | None = None,
    kg_relationships: list[str] | None = None,
    kg_terms: list[str] | None = None,
) -> list[LlmDoc]:
    # new db session to avoid concurrency issues
@@ -33,6 +36,9 @@ def research(
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
            kg_entities=kg_entities,
            kg_relationships=kg_relationships,
            kg_terms=kg_terms,
        ),
    ):
        # get retrieved docs to send to the rest of the graph
@@ -0,0 +1,54 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal

from langgraph.types import Send

from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput


def simple_vs_search(
    state: MainState,
) -> Literal["process_kg_only_answers", "construct_deep_search_filters"]:
    if state.strategy == KGAnswerStrategy.DEEP or len(state.relationships) > 0:
        return "construct_deep_search_filters"
    else:
        return "process_kg_only_answers"


def research_individual_object(
    state: MainState,
    # ) -> list[Send | Hashable] | Literal["individual_deep_search"]:
) -> list[Send | Hashable]:
    edge_start_time = datetime.now()

    # if (
    #     not state.div_con_entities
    #     or not state.broken_down_question
    #     or not state.vespa_filter_results
    # ):
    #     return "individual_deep_search"

    # else:

    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None

    return [
        Send(
            "process_individual_deep_search",
            ResearchObjectInput(
                entity=entity,
                broken_down_question=state.broken_down_question,
                vespa_filter_results=state.vespa_filter_results,
                source_division=state.source_division,
                log_messages=[
                    f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                ],
            ),
        )
        for entity in state.div_con_entities
    ]
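The `research_individual_object` edge above fans work out with LangGraph's `Send` primitive: one `ResearchObjectInput` payload is dispatched per entity, and each becomes an independent invocation of the `process_individual_deep_search` node. A stripped-down, self-contained sketch of the same pattern follows; the names `FanState`, `fan_out`, and `worker` are hypothetical and not part of this diff.

```python
# Minimal illustration of the Send fan-out (map-reduce) pattern used above.
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanState(TypedDict):
    items: list[str]
    results: Annotated[list[str], operator.add]  # parallel branch results are appended


def fan_out(state: FanState) -> list[Send]:
    # One Send per item -> one parallel "worker" invocation per item.
    return [Send("worker", {"items": [item], "results": []}) for item in state["items"]]


def worker(state: FanState) -> dict:
    return {"results": [f"processed:{state['items'][0]}"]}


builder = StateGraph(FanState)
builder.add_node("worker", worker)
builder.add_conditional_edges(START, fan_out, ["worker"])
builder.add_edge("worker", END)
graph = builder.compile()

print(graph.invoke({"items": ["a", "b"], "results": []}))
```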
backend/onyx/agents/agent_search/kb_search/graph_builder.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.kb_search.conditional_edges import (
    research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
    generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
    construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
    process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b3_consoldidate_individual_deep_search import (
    consoldidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
    process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger

logger = setup_logger()

test_mode = False


def kb_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """

    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###

    graph.add_node(
        "extract_ert",
        extract_ert,
    )

    graph.add_node(
        "generate_simple_sql",
        generate_simple_sql,
    )

    graph.add_node(
        "analyze",
        analyze,
    )

    graph.add_node(
        "generate_answer",
        generate_answer,
    )

    graph.add_node(
        "construct_deep_search_filters",
        construct_deep_search_filters,
    )

    graph.add_node(
        "process_individual_deep_search",
        process_individual_deep_search,
    )

    # graph.add_node(
    #     "individual_deep_search",
    #     individual_deep_search,
    # )

    graph.add_node(
        "consoldidate_individual_deep_search",
        consoldidate_individual_deep_search,
    )

    graph.add_node("process_kg_only_answers", process_kg_only_answers)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="extract_ert")

    graph.add_edge(
        start_key="extract_ert",
        end_key="analyze",
    )

    graph.add_edge(
        start_key="analyze",
        end_key="generate_simple_sql",
    )

    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)

    graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")

    # graph.add_edge(
    #     start_key="construct_deep_search_filters",
    #     end_key="process_individual_deep_search",
    # )

    graph.add_conditional_edges(
        source="construct_deep_search_filters",
        path=research_individual_object,
        path_map=["process_individual_deep_search"],
    )

    graph.add_edge(
        start_key="process_individual_deep_search",
        end_key="consoldidate_individual_deep_search",
    )

    # graph.add_edge(
    #     start_key="individual_deep_search",
    #     end_key="consoldidate_individual_deep_search",
    # )

    graph.add_edge(
        start_key="consoldidate_individual_deep_search", end_key="generate_answer"
    )

    graph.add_edge(
        start_key="generate_answer",
        end_key=END,
    )

    return graph
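A builder like this is typically compiled and then invoked with the graph's input state. The sketch below is an assumption about how the pieces fit together, based only on how the nodes above read `config["metadata"]["config"]`; the actual wiring lives elsewhere in the Onyx codebase.

```python
# Sketch only: compiling and invoking the graph. `main_input` (a MainInput instance)
# and `graph_config` (a GraphConfig) are assumed to be constructed by the caller;
# their construction is not shown in this diff.
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder

compiled = kb_graph_builder().compile()

result = compiled.invoke(
    input=main_input,
    config={"metadata": {"config": graph_config}},  # nodes read GraphConfig from here
)
```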
backend/onyx/agents/agent_search/kb_search/graph_utils.py (new file, 217 lines)
@@ -0,0 +1,217 @@
from typing import Set

from onyx.agents.agent_search.kb_search.models import KGExpendedGraphObjects
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_relationships_of_entity
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _check_entities_disconnected(
    current_entities: list[str], current_relationships: list[str]
) -> bool:
    """
    Check whether the entities in current_entities are disconnected given the provided
    relationships. Relationships are in the format:
    source_entity__relationship_name__target_entity

    Args:
        current_entities: List of entity IDs to check connectivity for
        current_relationships: List of relationships in format source__relationship__target

    Returns:
        bool: True if at least one entity is not reachable from the others
        (i.e., the entities are not all connected), False if all entities are connected
    """
    if not current_entities:
        return True

    # Create a graph representation using adjacency list
    graph: dict[str, set[str]] = {entity: set() for entity in current_entities}

    # Build the graph from relationships
    for relationship in current_relationships:
        try:
            source, _, target = relationship.split("__")
            if source in graph and target in graph:
                graph[source].add(target)
                graph[target].add(source)  # Add reverse edge since graph is undirected
        except ValueError:
            raise ValueError(f"Invalid relationship format: {relationship}")

    # Use BFS to check if all entities are connected
    visited: set[str] = set()
    start_entity = current_entities[0]

    def _bfs(start: str) -> None:
        queue = [start]
        visited.add(start)
        while queue:
            current = queue.pop(0)
            for neighbor in graph[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    # Start BFS from the first entity
    _bfs(start_entity)

    logger.debug(f"Number of visited entities: {len(visited)}")

    # Check if all current_entities are in visited
    return not all(entity in visited for entity in current_entities)


def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 2
) -> KGExpendedGraphObjects:
    """
    Find the minimal subgraph that connects all input entities, using only general entities
    (<entity_type>:*) as intermediate nodes. The subgraph will include only the relationships
    necessary to connect all input entities through the shortest possible paths.

    Args:
        entities: Initial list of entity IDs
        relationships: Initial list of relationships in format source__relationship__target
        max_depth: Maximum depth to expand the graph (default: 2)

    Returns:
        KGExpendedGraphObjects containing expanded entities and relationships
    """
    # Create copies of input lists to avoid modifying originals
    expanded_entities = entities.copy()
    expanded_relationships = relationships.copy()

    # Keep track of original entities
    original_entities = set(entities)

    # Build initial graph from existing relationships
    graph: dict[str, set[tuple[str, str]]] = {
        entity: set() for entity in expanded_entities
    }
    for rel in relationships:
        try:
            source, rel_name, target = rel.split("__")
            if source in graph and target in graph:
                graph[source].add((target, rel_name))
                graph[target].add((source, rel_name))
        except ValueError:
            continue

    # For each depth level
    counter = 0
    while counter < max_depth:
        # Find all connected components in the current graph
        components = []
        visited = set()

        def dfs(node: str, component: set[str]) -> None:
            visited.add(node)
            component.add(node)
            for neighbor, _ in graph.get(node, set()):
                if neighbor not in visited:
                    dfs(neighbor, component)

        # Find all components
        for entity in expanded_entities:
            if entity not in visited:
                component: Set[str] = set()
                dfs(entity, component)
                components.append(component)

        # If we only have one component, we're done
        if len(components) == 1:
            break

        # Find the shortest path between any two components using general entities
        shortest_path = None
        shortest_path_length = float("inf")

        for comp1 in components:
            for comp2 in components:
                if comp1 == comp2:
                    continue

                # Try to find path between entities in different components
                for entity1 in comp1:
                    if not any(e in original_entities for e in comp1):
                        continue

                    # entity1_type = entity1.split(":")[0]

                    with get_session_with_current_tenant() as db_session:
                        entity1_rels = get_relationships_of_entity(db_session, entity1)

                    for rel1 in entity1_rels:
                        try:
                            source1, rel_name1, target1 = rel1.split("__")
                            if source1 != entity1:
                                continue

                            target1_type = target1.split(":")[0]
                            general_target = f"{target1_type}:*"

                            # Try to find path from general_target to comp2
                            for entity2 in comp2:
                                if not any(e in original_entities for e in comp2):
                                    continue

                                with get_session_with_current_tenant() as db_session:
                                    entity2_rels = get_relationships_of_entity(
                                        db_session, entity2
                                    )

                                for rel2 in entity2_rels:
                                    try:
                                        source2, rel_name2, target2 = rel2.split("__")
                                        if target2 != entity2:
                                            continue

                                        source2_type = source2.split(":")[0]
                                        general_source = f"{source2_type}:*"

                                        if general_target == general_source:
                                            # Found a path of length 2
                                            path = [
                                                (entity1, rel_name1, general_target),
                                                (general_target, rel_name2, entity2),
                                            ]
                                            if len(path) < shortest_path_length:
                                                shortest_path = path
                                                shortest_path_length = len(path)

                                    except ValueError:
                                        continue

                        except ValueError:
                            continue

        # If we found a path, add it to our graph
        if shortest_path:
            for source, rel_name, target in shortest_path:
                # Add general entity if needed
                if ":*" in source and source not in expanded_entities:
                    expanded_entities.append(source)
                if ":*" in target and target not in expanded_entities:
                    expanded_entities.append(target)

                # Add relationship
                rel = f"{source}__{rel_name}__{target}"
                if rel not in expanded_relationships:
                    expanded_relationships.append(rel)

                # Update graph
                if source not in graph:
                    graph[source] = set()
                if target not in graph:
                    graph[target] = set()
                graph[source].add((target, rel_name))
                graph[target].add((source, rel_name))

        counter += 1

    logger.debug(f"Number of expanded entities: {len(expanded_entities)}")
    logger.debug(f"Number of expanded relationships: {len(expanded_relationships)}")

    return KGExpendedGraphObjects(
        entities=expanded_entities, relationships=expanded_relationships
    )
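A small worked example of the string conventions these helpers rely on: entities are `TYPE:name`, relationships are `source__relationship__target`, and `<entity_type>:*` denotes a general entity. The entity and relationship names below are hypothetical.

```python
# Hypothetical values illustrating the expected string formats,
# using the _check_entities_disconnected helper defined above.
entities = ["ACCOUNT:acme", "PERSON:jane"]
relationships = ["ACCOUNT:acme__employs__PERSON:jane"]

# Both entities reachable from one another -> not disconnected.
assert _check_entities_disconnected(entities, relationships) is False

# No connecting relationship -> the pair is disconnected.
assert _check_entities_disconnected(entities, []) is True
```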
backend/onyx/agents/agent_search/kb_search/models.py (new file, 34 lines)
@@ -0,0 +1,34 @@
from pydantic import BaseModel

from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import YesNoEnum


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


class KGAnswerApproach(BaseModel):
    strategy: KGAnswerStrategy
    format: KGAnswerFormat
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None


class KGExpendedGraphObjects(BaseModel):
    entities: list[str]
    relationships: list[str]
@@ -0,0 +1,240 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
    KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import (
    ERTExtractionUpdate,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query

    # first four lines duplicates from generate_initial_answer
    question = graph_config.inputs.search_request.query
    today_date = datetime.now().strftime("%A, %Y-%m-%d")

    all_entity_types = get_entity_types_str(active=None)
    all_relationship_types = get_relationship_types_str(active=None)

    ### get the entities, terms, and filters

    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )

    query_extraction_prompt = query_extraction_pre_prompt.replace(
        "---content---", question
    ).replace("---today_date---", today_date)

    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            entity_extraction_result = (
                KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            entity_extraction_result = KGQuestionEntityExtractionResult(
                entities=[],
                terms=[],
                time_filter="",
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[],
            terms=[],
            time_filter="",
        )

    ert_entities_string = f"Entities: {entity_extraction_result.entities}\n"
    # ert_terms_string = f"Terms: {entity_extraction_result.terms}"
    # ert_time_filter_string = f"Time Filter: {entity_extraction_result.time_filter}\n"

    ### get the relationships

    # find the relationship types that match the extracted entity types

    with get_session_with_current_tenant() as db_session:
        allowed_relationship_pairs = get_allowed_relationship_type_pairs(
            db_session, entity_extraction_result.entities
        )

    query_relationship_extraction_prompt = (
        QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
        .replace("---today_date---", today_date)
        .replace(
            "---relationship_type_options---",
            " - " + "\n - ".join(allowed_relationship_pairs),
        )
        .replace("---identified_entities---", ert_entities_string)
        .replace("---entity_types---", all_entity_types)
    )

    msg = [
        HumanMessage(
            content=query_relationship_extraction_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            relationship_extraction_result = (
                KGQuestionRelationshipExtractionResult.model_validate_json(
                    cleaned_response
                )
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            relationship_extraction_result = KGQuestionRelationshipExtractionResult(
                relationships=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        relationship_extraction_result = KGQuestionRelationshipExtractionResult(
            relationships=[],
        )

    # ert_relationships_string = (
    #     f"Relationships: {relationship_extraction_result.relationships}\n"
    # )

    ##

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_entities_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_relationships_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_terms_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_time_filter_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # dispatch_main_answer_stop_info(0, writer)

    return ERTExtractionUpdate(
        entities_types_str=all_entity_types,
        entities=entity_extraction_result.entities,
        relationships=relationship_extraction_result.relationships,
        terms=entity_extraction_result.terms,
        time_filter=entity_extraction_result.time_filter,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="extract entities terms",
                node_start_time=node_start_time,
            )
        ],
    )
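The fence-stripping and brace-slicing above reappears in several nodes on this branch (analyze, SQL generation, filter construction). A hedged sketch of what that cleanup would look like factored into one helper; this helper does not exist in the diff.

```python
# Hypothetical helper (not in the diff): the LLM-output cleanup repeated across nodes.
def clean_llm_json(raw: str) -> str:
    """Strip markdown code fences and newlines, then slice to the outermost JSON object."""
    cleaned = raw.replace("```json\n", "").replace("\n```", "").replace("\n", "")
    first, last = cleaned.find("{"), cleaned.rfind("}")
    if first == -1 or last == -1:
        return cleaned  # no JSON object found; let the caller's validation fail loudly
    return cleaned[first : last + 1]
```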
backend/onyx/agents/agent_search/kb_search/nodes/a2_analyze.py (new file, 210 lines)
@@ -0,0 +1,210 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import (
    create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def analyze(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities = state.entities
    relationships = state.relationships
    terms = state.terms
    time_filter = state.time_filter

    normalized_entities = normalize_entities(entities)
    normalized_relationships = normalize_relationships(
        relationships, normalized_entities.entity_normalization_map
    )
    normalized_terms = normalize_terms(terms)
    normalized_time_filter = time_filter

    # Expand the entities and relationships to make sure that entities are connected

    graph_expansion = create_minimal_connected_query_graph(
        normalized_entities.entities,
        normalized_relationships.relationships,
        max_depth=2,
    )

    query_graph_entities = graph_expansion.entities
    query_graph_relationships = graph_expansion.relationships

    # Evaluate whether a search needs to be done after identifying all entities and relationships

    strategy_generation_prompt = (
        STRATEGY_GENERATION_PROMPT.replace(
            "---entities---", "\n".join(query_graph_entities)
        )
        .replace("---relationships---", "\n".join(query_graph_relationships))
        .replace("---terms---", "\n".join(normalized_terms.terms))
        .replace("---question---", question)
    )

    msg = [
        HumanMessage(
            content=strategy_generation_prompt,
        )
    ]
    # fast_llm = graph_config.tooling.fast_llm
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            20,
            # fast_llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=5,
            max_tokens=100,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            approach_extraction_result = KGAnswerApproach.model_validate_json(
                cleaned_response
            )
            strategy = approach_extraction_result.strategy
            output_format = approach_extraction_result.format
            broken_down_question = approach_extraction_result.broken_down_question
            divide_and_conquer = approach_extraction_result.divide_and_conquer
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            strategy = KGAnswerStrategy.DEEP
            output_format = KGAnswerFormat.TEXT
            broken_down_question = None
            divide_and_conquer = YesNoEnum.NO
        if strategy is None or output_format is None:
            raise ValueError(f"Invalid strategy: {cleaned_response}")

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(normalized_entities.entities),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(normalized_relationships.relationships),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(query_graph_entities),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(query_graph_relationships),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=strategy.value,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=output_format.value,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # dispatch_main_answer_stop_info(0, writer)

    return AnalysisUpdate(
        normalized_core_entities=normalized_entities.entities,
        normalized_core_relationships=normalized_relationships.relationships,
        query_graph_entities=query_graph_entities,
        query_graph_relationships=query_graph_relationships,
        normalized_terms=normalized_terms.terms,
        normalized_time_filter=normalized_time_filter,
        strategy=strategy,
        broken_down_question=broken_down_question,
        output_format=output_format,
        divide_and_conquer=divide_and_conquer,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="analyze",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -0,0 +1,224 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text

from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def _sql_is_aggregate_query(sql_statement: str) -> bool:
    return any(
        agg_func in sql_statement.upper()
        for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
    )


def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
    """
    Remove aggregate functions from the SQL statement.
    """

    sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
        "---sql_statement---", sql_statement
    )

    msg = [
        HumanMessage(
            content=sql_aggregation_removal_prompt,
        )
    ]

    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=800,
        )

        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        sql_statement = cleaned_response.split("SQL:")[1].strip()
        sql_statement = sql_statement.split(";")[0].strip() + ";"
        sql_statement = sql_statement.replace("sql", "").strip()

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    return sql_statement


def generate_simple_sql(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities_types_str = state.entities_types_str
    state.strategy
    state.output_format

    simple_sql_prompt = (
        SIMPLE_SQL_PROMPT.replace("---entities_types---", entities_types_str)
        .replace("---question---", question)
        .replace("---query_entities---", "\n".join(state.query_graph_entities))
        .replace(
            "---query_relationships---", "\n".join(state.query_graph_relationships)
        )
    )

    msg = [
        HumanMessage(
            content=simple_sql_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=800,
        )

        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        sql_statement = cleaned_response.split("SQL:")[1].strip()
        sql_statement = sql_statement.split(";")[0].strip() + ";"
        sql_statement = sql_statement.replace("sql", "").strip()

        # reasoning = cleaned_response.split("SQL:")[0].strip()

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    if _sql_is_aggregate_query(sql_statement):
        individualized_sql_query = _remove_aggregation(sql_statement, llm=fast_llm)
    else:
        individualized_sql_query = None

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=reasoning,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=cleaned_response,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # CRITICAL: EXECUTION OF SQL NEEDS TO BE MADE SAFE FOR PRODUCTION
    with get_session_with_current_tenant() as db_session:
        try:
            result = db_session.execute(text(sql_statement))
            # Handle scalar results (like COUNT)
            if sql_statement.upper().startswith("SELECT COUNT"):
                scalar_result = result.scalar()
                query_results = (
                    [{"count": int(scalar_result) - 1}]
                    if scalar_result is not None
                    else []
                )
            else:
                # Handle regular row results
                rows = result.fetchall()
                query_results = [dict(row._mapping) for row in rows]
        except Exception as e:
            logger.error(f"Error executing SQL query: {e}")

            raise e

    if (
        individualized_sql_query is not None
        and individualized_sql_query != sql_statement
    ):
        with get_session_with_current_tenant() as db_session:
            try:
                result = db_session.execute(text(individualized_sql_query))
                # Handle scalar results (like COUNT)
                if individualized_sql_query.upper().startswith("SELECT COUNT"):
                    scalar_result = result.scalar()
                    individualized_query_results = (
                        [{"count": int(scalar_result) - 1}]
                        if scalar_result is not None
                        else []
                    )
                else:
                    # Handle regular row results
                    rows = result.fetchall()
                    individualized_query_results = [dict(row._mapping) for row in rows]
            except Exception as e:
                # No stopping here, the individualized SQL query is not mandatory
                logger.error(f"Error executing Individualized SQL query: {e}")
                individualized_query_results = None

    else:
        individualized_query_results = None

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=str(query_results),
    #         level=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # dispatch_main_answer_stop_info(0, writer)

    logger.info(f"query_results: {query_results}")

    return SQLSimpleGenerationUpdate(
        sql_query=sql_statement,
        query_results=query_results,
        individualized_sql_query=individualized_sql_query,
        individualized_query_results=individualized_query_results,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="generate simple sql",
                node_start_time=node_start_time,
            )
        ],
    )
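The `CRITICAL` comment above flags that executing LLM-generated SQL directly is unsafe. One possible mitigation, shown only as a sketch and not part of this diff, is to force the statement to run inside a read-only transaction (PostgreSQL syntax assumed) so that writes and DDL fail instead of mutating data.

```python
# Sketch only (not in the diff): run LLM-generated SQL in a read-only transaction.
from sqlalchemy import text

with get_session_with_current_tenant() as db_session:
    # Must be the first statement in the transaction; PostgreSQL-specific.
    db_session.execute(text("SET TRANSACTION READ ONLY"))
    result = db_session.execute(text(sql_statement))
    rows = result.fetchall()
```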
@@ -0,0 +1,158 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entity_types_with_grounded_source_name
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def construct_deep_search_filters(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query

    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query

    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )

    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]
    llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            vespa_filter_results = KGVespaFilterResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            vespa_filter_results = KGVespaFilterResults(
                entity_filters=[],
                relationship_filters=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        vespa_filter_results = KGVespaFilterResults(
            entity_filters=[],
            relationship_filters=[],
        )

    if (
        state.individualized_query_results
        and len(state.individualized_query_results) > 0
    ):
        div_con_entities = [
            x["id_name"]
            for x in state.individualized_query_results
            if x["id_name"] is not None and "*" not in x["id_name"]
        ]
    elif state.query_results and len(state.query_results) > 0:
        div_con_entities = [
            x["id_name"]
            for x in state.query_results
            if x["id_name"] is not None and "*" not in x["id_name"]
        ]
    else:
        div_con_entities = []

    div_con_entities = list(set(div_con_entities))

    logger.info(f"div_con_entities: {div_con_entities}")

    with get_session_with_current_tenant() as db_session:
        double_grounded_entity_types = get_entity_types_with_grounded_source_name(
            db_session
        )

    source_division = False

    if div_con_entities:
        for entity_type in double_grounded_entity_types:
            if entity_type.grounded_source_name.lower() in div_con_entities[0].lower():
                source_division = True
                break
    else:
        raise ValueError("No div_con_entities found")

    return DeepSearchFilterUpdate(
        vespa_filter_results=vespa_filter_results,
        div_con_entities=div_con_entities,
        source_division=source_division,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
    )
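For clarity, the `div_con_entities` selection above keeps only concrete entity IDs: rows without an `id_name` and wildcard/general entities are dropped, and duplicates are removed. A tiny worked example with hypothetical rows, using the same filter condition as the node above.

```python
# Hypothetical query rows illustrating the div_con_entities filtering above.
query_results = [
    {"id_name": "ACCOUNT:acme"},
    {"id_name": "ACCOUNT:*"},     # general entity -> dropped by the "*" check
    {"id_name": "ACCOUNT:acme"},  # duplicate -> removed by the set()
    {"id_name": None},            # missing id -> dropped
]

div_con_entities = list(
    {
        x["id_name"]
        for x in query_results
        if x["id_name"] is not None and "*" not in x["id_name"]
    }
)
assert div_con_entities == ["ACCOUNT:acme"]
```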
@@ -0,0 +1,133 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def process_individual_deep_search(
|
||||
state: ResearchObjectInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ResearchObjectUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = state.broken_down_question
|
||||
source_division = state.source_division
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
|
||||
object = state.entity.replace(":", ": ").lower()
|
||||
|
||||
if source_division:
|
||||
extended_question = question
|
||||
else:
|
||||
extended_question = f"{question} in regards to {object}"
|
||||
|
||||
kg_entity_filters = copy.deepcopy(
|
||||
state.vespa_filter_results.entity_filters + [state.entity]
|
||||
)
|
||||
kg_relationship_filters = copy.deepcopy(
|
||||
state.vespa_filter_results.relationship_filters
|
||||
)
|
||||
|
||||
logger.info("Research for object: " + object)
|
||||
logger.info(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Add random wait between 1-3 seconds
|
||||
# time.sleep(random.uniform(0, 3))
|
||||
|
||||
retrieved_docs = research(
|
||||
question=extended_question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
for doc_num, doc in enumerate(retrieved_docs):
|
||||
chunk_text = "Document " + str(doc_num + 1) + ":\n" + doc.content
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Built prompt
|
||||
|
||||
datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
question=extended_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
# fast_llm = graph_config.tooling.fast_llm
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Invoke the LLM with a timeout
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
object_research_results = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
|
||||
return ResearchObjectUpdate(
|
||||
research_object_results=[
|
||||
{
|
||||
"object": object.replace(":", ": ").capitalize(),
|
||||
"results": object_research_results,
|
||||
}
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="process individual deep search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def individual_deep_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> DeepSearchFilterUpdate:
|
||||
"""
|
||||
LangGraph node to construct the KG/Vespa filters for the individual deep searches.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities
|
||||
relationships = state.query_graph_relationships
|
||||
simple_sql_query = state.sql_query
|
||||
|
||||
search_filter_construction_prompt = (
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
|
||||
"---entity_type_descriptions---",
|
||||
entities_types_str,
|
||||
)
|
||||
.replace(
|
||||
"---entity_filters---",
|
||||
"\n".join(entities),
|
||||
)
|
||||
.replace(
|
||||
"---relationship_filters---",
|
||||
"\n".join(relationships),
|
||||
)
|
||||
.replace(
|
||||
"---sql_query---",
|
||||
simple_sql_query or "(no SQL generated)",
|
||||
)
|
||||
.replace(
|
||||
"---question---",
|
||||
question,
|
||||
)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=search_filter_construction_prompt,
|
||||
)
|
||||
]
|
||||
llm = graph_config.tooling.primary_llm
|
||||
# Run the LLM to construct the search filters
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
|
||||
try:
|
||||
vespa_filter_results = KGVespaFilterResults.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
vespa_filter_results = KGVespaFilterResults(
|
||||
entity_filters=[],
|
||||
relationship_filters=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
vespa_filter_results = KGVespaFilterResults(
|
||||
entity_filters=[],
|
||||
relationship_filters=[],
|
||||
)
|
||||
|
||||
if state.query_results:
|
||||
div_con_entities = [
|
||||
x["id_name"] for x in state.query_results if x["id_name"] is not None
|
||||
]
|
||||
else:
|
||||
div_con_entities = []
|
||||
|
||||
return DeepSearchFilterUpdate(
|
||||
vespa_filter_results=vespa_filter_results,
|
||||
div_con_entities=div_con_entities,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="construct deep search filters",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,45 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def consoldidate_individual_deep_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to consolidate the individual deep search results.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.search_request.query
|
||||
state.entities_types_str
|
||||
|
||||
research_object_results = state.research_object_results
|
||||
|
||||
consolidated_research_object_results_str = "\n".join(
|
||||
[f"{x['object']}: {x['results']}" for x in research_object_results]
|
||||
)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=consolidated_research_object_results_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate simple sql",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,107 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.document import get_base_llm_doc_information
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _general_format(result: dict[str, Any]) -> str:
|
||||
name = result.get("name")
|
||||
entity_type_id_name = result.get("entity_type_id_name")
|
||||
result.get("id_name")
|
||||
|
||||
if entity_type_id_name:
|
||||
return f"{entity_type_id_name.capitalize()}: {name}"
|
||||
else:
|
||||
return f"{name}"
|
||||
|
||||
|
||||
def _generate_reference_results(
|
||||
individualized_query_results: list[dict[str, Any]]
|
||||
) -> ReferenceResults:
|
||||
"""
|
||||
Generate reference results from the query results data string.
|
||||
"""
|
||||
|
||||
citations: list[str] = []
|
||||
general_entities = []
|
||||
|
||||
# get all entities that correspond to an Onyx document
|
||||
document_ids: list[str] = [
|
||||
cast(str, x.get("document_id"))
|
||||
for x in individualized_query_results
|
||||
if x.get("document_id")
|
||||
]
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
llm_doc_information_results = get_base_llm_doc_information(
|
||||
session, document_ids
|
||||
)
|
||||
|
||||
for llm_doc_information_result in llm_doc_information_results:
|
||||
citations.append(llm_doc_information_result.center_chunk.semantic_identifier)
|
||||
|
||||
for result in individualized_query_results:
|
||||
document_id: str | None = result.get("document_id")
|
||||
|
||||
if not document_id:
|
||||
general_entities.append(_general_format(result))
|
||||
|
||||
return ReferenceResults(citations=citations, general_entities=general_entities)
|
||||
|
||||
|
||||
def process_kg_only_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ResultsDataUpdate:
|
||||
"""
|
||||
LangGraph node to process the KG query results into answer data strings and references.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.search_request.query
|
||||
query_results = state.query_results
|
||||
individualized_query_results = state.individualized_query_results
|
||||
|
||||
query_results_list = []
|
||||
|
||||
if query_results:
|
||||
for query_result in query_results:
|
||||
query_results_list.append(str(query_result).replace(":", ": ").capitalize())
|
||||
else:
|
||||
raise ValueError("No query results were found")
|
||||
|
||||
query_results_data_str = "\n".join(query_results_list)
|
||||
|
||||
if individualized_query_results:
|
||||
reference_results = _generate_reference_results(individualized_query_results)
|
||||
else:
|
||||
reference_results = None
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
individualized_query_results_data_str="",
|
||||
reference_results=reference_results,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="kg query results data processing",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,157 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
"""
|
||||
LangGraph node to generate and stream the final answer.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
state.entities_types_str
|
||||
introductory_answer = state.query_results_data_str
|
||||
state.reference_results
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
|
||||
consolidated_research_object_results_str = (
|
||||
state.consolidated_research_object_results_str
|
||||
)
|
||||
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
output_format = state.output_format
|
||||
if state.reference_results:
|
||||
examples = (
|
||||
state.reference_results.citations
|
||||
or state.reference_results.general_entities
|
||||
or []
|
||||
)
|
||||
research_results = "\n".join([f"- {example}" for example in examples])
|
||||
elif consolidated_research_object_results_str:
|
||||
research_results = consolidated_research_object_results_str
|
||||
else:
|
||||
research_results = ""
|
||||
|
||||
if research_results and introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_PROMPT.replace("---question---", question)
|
||||
.replace("---introductory_answer---", introductory_answer)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace("---research_results---", research_results)
|
||||
)
|
||||
|
||||
elif not research_results and introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace("---introductory_answer---", introductory_answer)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and not introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace("---research_results---", research_results)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No research results or introductory answer provided")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=output_format_prompt,
|
||||
)
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
dispatch_timings: list[float] = []
|
||||
response: list[str] = []
|
||||
|
||||
def stream_answer() -> list[str]:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=1000,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
start_stream_token = datetime.now()
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=content,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
# logger.debug(f"Answer piece: {content}")
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
30,
|
||||
stream_answer,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_ANSWER,
|
||||
level=0,
|
||||
level_question_num=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
dispatch_main_answer_stop_info(0, writer)
|
||||
|
||||
return MainOutput(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="query completed",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
backend/onyx/agents/agent_search/kb_search/ops.py (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def research(
|
||||
question: str,
|
||||
search_tool: SearchTool,
|
||||
document_sources: list[DocumentSource] | None = None,
|
||||
time_cutoff: datetime | None = None,
|
||||
) -> list[LlmDoc]:
|
||||
# new db session to avoid concurrency issues
|
||||
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
retrieved_docs: list[LlmDoc] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
document_sources=document_sources,
|
||||
time_cutoff=time_cutoff,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
|
||||
break
|
||||
return retrieved_docs
|
||||
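# Illustrative usage sketch (assumptions noted inline, not from the diff): a minimal
# call to the research() helper defined above. `search_tool` is assumed to come from
# graph_config.tooling.search_tool as in the kb_search nodes; the question and the
# 90-day cutoff are made-up example values.
from datetime import datetime, timedelta

docs = research(
    question="What did we ship for Acme Corp last quarter?",
    search_tool=search_tool,  # assumed: provided by the graph tooling
    document_sources=None,  # or a list of DocumentSource values to restrict sources
    time_cutoff=datetime.now() - timedelta(days=90),  # optional recency filter
)
# At most the 10 top-ranked LlmDocs from the FINAL_CONTEXT_DOCUMENTS_ID response are returned.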
backend/onyx/agents/agent_search/kb_search/states.py (new file, 127 lines)
@@ -0,0 +1,127 @@
|
||||
from enum import Enum
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
|
||||
|
||||
|
||||
### States ###
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class KGVespaFilterResults(BaseModel):
|
||||
entity_filters: list[str]
|
||||
relationship_filters: list[str]
|
||||
|
||||
|
||||
class KGAnswerStrategy(Enum):
|
||||
DEEP = "DEEP"
|
||||
SIMPLE = "SIMPLE"
|
||||
|
||||
|
||||
class KGAnswerFormat(Enum):
|
||||
LIST = "LIST"
|
||||
TEXT = "TEXT"
|
||||
|
||||
|
||||
class YesNoEnum(str, Enum):
|
||||
YES = "yes"
|
||||
NO = "no"
|
||||
|
||||
|
||||
class AnalysisUpdate(LoggerUpdate):
|
||||
normalized_core_entities: list[str] = []
|
||||
normalized_core_relationships: list[str] = []
|
||||
query_graph_entities: list[str] = []
|
||||
query_graph_relationships: list[str] = []
|
||||
normalized_terms: list[str] = []
|
||||
normalized_time_filter: str | None = None
|
||||
strategy: KGAnswerStrategy | None = None
|
||||
output_format: KGAnswerFormat | None = None
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
|
||||
|
||||
class SQLSimpleGenerationUpdate(LoggerUpdate):
|
||||
sql_query: str | None = None
|
||||
query_results: list[Dict[Any, Any]] | None = None
|
||||
individualized_sql_query: str | None = None
|
||||
individualized_query_results: list[Dict[Any, Any]] | None = None
|
||||
|
||||
|
||||
class ConsolidatedResearchUpdate(LoggerUpdate):
|
||||
consolidated_research_object_results_str: str | None = None
|
||||
|
||||
|
||||
class DeepSearchFilterUpdate(LoggerUpdate):
|
||||
vespa_filter_results: KGVespaFilterResults | None = None
|
||||
div_con_entities: list[str] | None = None
|
||||
source_division: bool | None = None
|
||||
|
||||
|
||||
class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class ERTExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
entities: list[str] = []
|
||||
relationships: list[str] = []
|
||||
terms: list[str] = []
|
||||
time_filter: str | None = None
|
||||
|
||||
|
||||
class ResultsDataUpdate(LoggerUpdate):
|
||||
query_results_data_str: str | None = None
|
||||
individualized_query_results_data_str: str | None = None
|
||||
reference_results: ReferenceResults | None = None
|
||||
|
||||
|
||||
class ResearchObjectUpdate(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
ERTExtractionUpdate,
|
||||
AnalysisUpdate,
|
||||
SQLSimpleGenerationUpdate,
|
||||
ResultsDataUpdate,
|
||||
ResearchObjectOutput,
|
||||
DeepSearchFilterUpdate,
|
||||
ResearchObjectUpdate,
|
||||
ConsolidatedResearchUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
|
||||
|
||||
class ResearchObjectInput(LoggerUpdate):
|
||||
entity: str
|
||||
broken_down_question: str
|
||||
vespa_filter_results: KGVespaFilterResults
|
||||
source_division: bool | None
|
||||
@@ -12,12 +12,16 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
|
||||
divide_and_conquer_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.graph_builder import dc_graph_builder
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
|
||||
from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
main_graph_builder as main_graph_builder_a,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -86,7 +90,7 @@ def _parse_agent_event(
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput | DCMainInput,
|
||||
graph_input: BasicInput | MainInput | KBMainInput | DCMainInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -100,7 +104,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput | DCMainInput,
|
||||
input: BasicInput | MainInput | KBMainInput | DCMainInput,
|
||||
) -> AnswerStream:
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
@@ -150,6 +154,15 @@ def run_basic_graph(
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_kb_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = kb_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = KBMainInput(log_messages=[])
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dc_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
|
||||
@@ -27,6 +27,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_sub_question_answer_prompt(
|
||||
question: str,
|
||||
|
||||
@@ -159,3 +159,9 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
class QueryExpansionType(Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
class ReferenceResults(BaseModel):
|
||||
# citations: list[InferenceSection]
|
||||
citations: list[str]
|
||||
general_entities: list[str]
|
||||
|
||||
@@ -11,6 +11,8 @@ from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_main_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
@@ -145,13 +147,21 @@ class Answer:
|
||||
|
||||
if self.graph_config.behavior.use_agentic_search:
|
||||
run_langgraph = run_main_graph
|
||||
elif (
|
||||
elif self.graph_config.inputs.search_request.persona:
|
||||
if (
|
||||
self.graph_config.inputs.search_request.persona
|
||||
and self.graph_config.inputs.search_request.persona.description.startswith(
|
||||
"DivCon Beta Agent"
|
||||
"DivCon Beta Agent"
|
||||
)
|
||||
):
|
||||
run_langgraph = run_dc_graph
|
||||
):
|
||||
run_langgraph = run_dc_graph
|
||||
elif self.graph_config.inputs.search_request.persona.name.startswith(
|
||||
"KG Dev"
|
||||
):
|
||||
run_langgraph = run_kb_graph
|
||||
else:
|
||||
run_langgraph = run_basic_graph
|
||||
|
||||
else:
|
||||
run_langgraph = run_basic_graph
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
@@ -100,6 +101,39 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
)
|
||||
|
||||
|
||||
def inference_section_from_llm_doc(llm_doc: LlmDoc) -> InferenceSection:
|
||||
# Create a center chunk first
|
||||
center_chunk = InferenceChunk(
|
||||
document_id=llm_doc.document_id,
|
||||
chunk_id=0, # Default to 0 since LlmDoc doesn't have this info
|
||||
content=llm_doc.content,
|
||||
blurb=llm_doc.blurb,
|
||||
semantic_identifier=llm_doc.semantic_identifier,
|
||||
source_type=llm_doc.source_type,
|
||||
metadata=llm_doc.metadata,
|
||||
updated_at=llm_doc.updated_at,
|
||||
source_links=llm_doc.source_links or {},
|
||||
match_highlights=llm_doc.match_highlights or [],
|
||||
section_continuation=False,
|
||||
image_file_name=None,
|
||||
title=None,
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=None,
|
||||
hidden=False,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
)
|
||||
|
||||
# Create InferenceSection with the center chunk
|
||||
# Since we don't have access to the original chunks, we'll use an empty list
|
||||
return InferenceSection(
|
||||
center_chunk=center_chunk,
|
||||
chunks=[], # Original surrounding chunks are not available in LlmDoc
|
||||
combined_content=llm_doc.content,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
|
||||
@@ -100,6 +100,8 @@ from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import load_all_user_file_files
|
||||
from onyx.file_store.utils import load_all_user_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.kg.clustering.clustering import kg_clustering
|
||||
from onyx.kg.extractions.extraction_processing import kg_extraction
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -664,6 +666,17 @@ def stream_chat_message_objects(
|
||||
|
||||
llm: LLM
|
||||
|
||||
index_str = "danswer_chunk_text_embedding_3_small"
|
||||
|
||||
if new_msg_req.message == "ee":
|
||||
kg_extraction(tenant_id, index_str)
|
||||
|
||||
raise Exception("Extractions done")
|
||||
|
||||
elif new_msg_req.message == "cc":
|
||||
kg_clustering(tenant_id, index_str)
|
||||
raise Exception("Clustering done")
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
file_id_to_user_file = {}
|
||||
|
||||
backend/onyx/configs/kg_configs.py (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
KG_OWN_EMAIL_DOMAINS: list[str] = json.loads(
|
||||
os.environ.get("KG_OWN_EMAIL_DOMAINS", "[]")
|
||||
) # must be list
|
||||
|
||||
KG_IGNORE_EMAIL_DOMAINS: list[str] = json.loads(
|
||||
os.environ.get("KG_IGNORE_EMAIL_DOMAINS", "[]")
|
||||
) # must be list
|
||||
|
||||
KG_OWN_COMPANY: str = os.environ.get("KG_OWN_COMPANY", "")
|
||||
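# Usage note (a sketch based on the json.loads parsing above; example values only):
# the two domain variables must be JSON-encoded lists, e.g.
#   KG_OWN_EMAIL_DOMAINS='["acme.com"]'
#   KG_IGNORE_EMAIL_DOMAINS='["gmail.com", "outlook.com"]'
#   KG_OWN_COMPANY='Acme'
# Unset variables fall back to empty lists / an empty string.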
@@ -113,6 +113,9 @@ class BaseFilters(BaseModel):
|
||||
tags: list[Tag] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
|
||||
@@ -68,6 +68,7 @@ class SearchPipeline:
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
|
||||
self.search_request = search_request
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
|
||||
@@ -182,6 +182,9 @@ def retrieval_preprocessing(
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
kg_entities=preset_filters.kg_entities,
|
||||
kg_relationships=preset_filters.kg_relationships,
|
||||
kg_terms=preset_filters.kg_terms,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
@@ -349,6 +349,8 @@ def retrieve_chunks(
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
logger.info(f"RETRIEVAL CHUNKS query: {query}")
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
# Don't do query expansion on complex queries, rephrasings likely would not work well
|
||||
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
|
||||
|
||||
@@ -334,3 +334,21 @@ def mark_ccpair_with_indexing_trigger(
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_unprocessed_connector_ids(db_session: Session) -> list[int]:
|
||||
"""
|
||||
Retrieves the IDs of all connectors that have KG extraction enabled (candidates for KG processing).
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[int]: IDs of connectors with KG extraction enabled
|
||||
"""
|
||||
try:
|
||||
stmt = select(Connector.id).where(Connector.kg_extraction_enabled)
|
||||
result = db_session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -23,6 +23,8 @@ from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
@@ -377,6 +379,7 @@ def upsert_documents(
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
kg_processed=False,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -843,3 +846,213 @@ def fetch_chunk_count_for_document(
|
||||
) -> int | None:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_unprocessed_kg_documents_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
batch_size: int = 100,
|
||||
) -> Generator[DbDocument, None, None]:
|
||||
"""
|
||||
Retrieves all documents associated with a connector that have not yet been processed
|
||||
for knowledge graph extraction. Uses a generator pattern to handle large result sets.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
connector_id (int): The ID of the connector to check
|
||||
batch_size (int): Number of documents to fetch per batch, defaults to 100
|
||||
|
||||
Yields:
|
||||
DbDocument: Documents that haven't been KG processed, one at a time
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
or_(
|
||||
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
|
||||
None
|
||||
),
|
||||
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
|
||||
False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
batch = list(db_session.scalars(stmt).all())
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for document in batch:
|
||||
yield document
|
||||
|
||||
offset += batch_size
|
||||
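# Illustrative usage sketch (assumptions noted inline, not from the diff): iterate the
# KG-enabled connectors and stream their unprocessed documents through the generator
# above. get_unprocessed_connector_ids() is the helper added in onyx/db/connector.py;
# kg_extract_document() is a hypothetical placeholder for the actual extraction step.
with get_session_context_manager() as db_session:
    for connector_id in get_unprocessed_connector_ids(db_session):
        for document in get_unprocessed_kg_documents_for_connector(
            db_session, connector_id, batch_size=100
        ):
            kg_extract_document(document)  # hypothetical extraction hook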
|
||||
|
||||
def get_kg_processed_document_ids(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_processed is True.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been KG processed
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_processed.is_(True))
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def update_document_kg_info(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
kg_processed: bool,
|
||||
kg_data: dict,
|
||||
) -> None:
|
||||
"""Updates the knowledge graph related information for a document.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to update
|
||||
kg_processed (bool): Whether the document has been processed for KG extraction
|
||||
kg_data (dict): Dictionary containing KG data with 'entities', 'relationships', and 'terms' keys
|
||||
|
||||
Raises:
|
||||
ValueError: If the document with the given ID is not found
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(
|
||||
kg_processed=kg_processed,
|
||||
kg_data=kg_data,
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def get_document_kg_info(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> tuple[bool, dict] | None:
|
||||
"""Retrieves the knowledge graph processing status and data for a document.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to query
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[bool, dict]]: A tuple containing:
|
||||
- bool: Whether the document has been KG processed
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Returns None if the document is not found
|
||||
"""
|
||||
stmt = select(DbDocument.kg_processed, DbDocument.kg_data).where(
|
||||
DbDocument.id == document_id
|
||||
)
|
||||
result = db_session.execute(stmt).one_or_none()
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.kg_processed, result.kg_data or {}
|
||||
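# Illustrative usage sketch (values are made up, not from the diff): record and read
# back the KG payload for a document inside an already open Session, following the
# payload shape described in the docstrings above.
update_document_kg_info(
    db_session,
    document_id="doc-123",
    kg_processed=True,
    kg_data={"entities": ["ACCOUNT:Acme Corp"], "relationships": [], "terms": []},
)
db_session.commit()

kg_info = get_document_kg_info(db_session, "doc-123")
if kg_info is not None:
    processed, kg_data = kg_info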
|
||||
|
||||
def get_all_kg_processed_documents_info(
|
||||
db_session: Session,
|
||||
) -> list[tuple[str, dict]]:
|
||||
"""Retrieves the knowledge graph data for all documents that have been processed.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, dict]]: A list of tuples containing:
|
||||
- str: The document ID
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Only returns documents where kg_processed is True
|
||||
"""
|
||||
stmt = (
|
||||
select(DbDocument.id, DbDocument.kg_data)
|
||||
.where(DbDocument.kg_processed.is_(True))
|
||||
.order_by(DbDocument.id)
|
||||
)
|
||||
|
||||
results = db_session.execute(stmt).all()
|
||||
return [(str(doc_id), kg_data or {}) for doc_id, kg_data in results]
|
||||
|
||||
|
||||
def get_base_llm_doc_information(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> list[InferenceSection]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
results = db_session.execute(stmt).all()
|
||||
|
||||
inference_sections = []
|
||||
|
||||
for doc in results:
|
||||
bare_doc = doc[0]
|
||||
inference_section = InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
document_id=bare_doc.id,
|
||||
chunk_id=0,
|
||||
source_type=DocumentSource.NOT_APPLICABLE,
|
||||
semantic_identifier=bare_doc.semantic_id,
|
||||
title=None,
|
||||
boost=0,
|
||||
recency_bias=0,
|
||||
score=0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
blurb="",
|
||||
content="",
|
||||
source_links=None,
|
||||
image_file_name=None,
|
||||
section_continuation=False,
|
||||
match_highlights=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
updated_at=None,
|
||||
),
|
||||
chunks=[],
|
||||
combined_content="",
|
||||
)
|
||||
|
||||
inference_sections.append(inference_section)
|
||||
return inference_sections
|
||||
|
||||
|
||||
def get_document_updated_at(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> datetime | None:
|
||||
"""Retrieves the doc_updated_at timestamp for a given document ID.
|
||||
|
||||
Args:
|
||||
document_id (str): The ID of the document to query
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
|
||||
"""
|
||||
if len(document_id.split(":")) == 2:
|
||||
document_id = document_id.split(":")[1]
|
||||
elif len(document_id.split(":")) > 2:
|
||||
raise ValueError(f"Invalid document ID: {document_id}")
|
||||
else:
|
||||
pass
|
||||
|
||||
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
backend/onyx/db/entities.py (new file, 271 lines)
@@ -0,0 +1,271 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityType
|
||||
|
||||
|
||||
def get_entity_types(
|
||||
db_session: Session,
|
||||
active: bool | None = True,
|
||||
) -> list[KGEntityType]:
|
||||
# Query the database for all distinct entity types
|
||||
|
||||
if active is None:
|
||||
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
|
||||
|
||||
else:
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.active == active)
|
||||
.order_by(KGEntityType.id_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def add_entity(
|
||||
db_session: Session,
|
||||
entity_type: str,
|
||||
name: str,
|
||||
document_id: str | None = None,
|
||||
cluster_count: int = 0,
|
||||
event_time: datetime | None = None,
|
||||
) -> "KGEntity | None":
|
||||
"""Add a new entity to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_type: Type of the entity (must match an existing KGEntityType)
|
||||
name: Name of the entity
|
||||
cluster_count: Number of clusters this entity has been found in
|
||||
|
||||
Returns:
|
||||
KGEntity: The created entity
|
||||
"""
|
||||
entity_type = entity_type.upper()
|
||||
name = name.title()
|
||||
id_name = f"{entity_type}:{name}"
|
||||
|
||||
# Create new entity
|
||||
stmt = (
|
||||
pg_insert(KGEntity)
|
||||
.values(
|
||||
id_name=id_name,
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
cluster_count=cluster_count,
|
||||
event_time=event_time,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
# Direct numeric addition without text()
|
||||
cluster_count=KGEntity.cluster_count
|
||||
+ literal_column("EXCLUDED.cluster_count"),
|
||||
# Keep other fields updated as before
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
event_time=event_time,
|
||||
),
|
||||
)
|
||||
.returning(KGEntity)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
return result
|
||||
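# Illustrative usage sketch (values are made up, not from the diff): the upsert above
# normalizes the inputs and accumulates cluster_count on conflict.
with get_session_with_current_tenant() as db_session:
    entity = add_entity(
        db_session,
        entity_type="account",  # stored as "ACCOUNT"
        name="acme corp",  # stored as "Acme Corp", id_name "ACCOUNT:Acme Corp"
        document_id="doc-123",
        cluster_count=1,
    )
    # A second call with the same id_name updates the existing row and adds the new
    # cluster_count to the stored value via EXCLUDED.cluster_count.
    db_session.commit()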
|
||||
|
||||
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
|
||||
"""
|
||||
Check if a document_id exists in the kg_entity table and return the matching KGEntity if found.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
document_id: The document ID to search for
|
||||
|
||||
Returns:
|
||||
The matching KGEntity if found, None otherwise
|
||||
"""
|
||||
query = select(KGEntity).where(KGEntity.document_id == document_id)
|
||||
result = db.execute(query).scalar()
|
||||
return result
|
||||
|
||||
|
||||
def get_ungrounded_entities(db_session: Session) -> List[KGEntity]:
|
||||
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to ungrounded entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntityType.grounding == "UE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_grounded_entities(db_session: Session) -> List[KGEntity]:
|
||||
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to grounded entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntityType.grounding == "GE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null ge_determine_instructions.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have ge_determine_instructions defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.ge_determine_instructions.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounded_source_name(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null grounded_source_name.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounded_source_name defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounding_signature(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null ge_grounding_signature.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have ge_grounding_signature defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.ge_grounding_signature.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_ge_entities_by_types(
|
||||
db_session: Session, entity_types: List[str]
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities matching an entity_type.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity types to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.filter(KGEntityType.grounding == "GE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def delete_entities_by_id_names(db_session: Session, id_names: list[str]) -> int:
|
||||
"""
|
||||
Delete entities from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of entity id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of entities deleted
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGEntity)
|
||||
.filter(KGEntity.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_entities_for_types(
|
||||
db_session: Session, entity_types: List[str]
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities that belong to the specified entity types.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity type id_names to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_type_by_grounded_source_name(
|
||||
db_session: Session, grounded_source_name: str
|
||||
) -> KGEntityType | None:
|
||||
"""Get an entity type by its grounded_source_name and return it as a dictionary.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
grounded_source_name: The grounded_source_name of the entity to retrieve
|
||||
|
||||
Returns:
|
||||
The matching KGEntityType object,
|
||||
or None if the entity is not found
|
||||
"""
|
||||
entity_type = (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name == grounded_source_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if entity_type is None:
|
||||
return None
|
||||
|
||||
return entity_type
|
||||
@@ -53,6 +53,7 @@ from onyx.db.enums import (
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
@@ -586,6 +587,22 @@ class Document(Base):
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# columns for the knowledge graph data
|
||||
kg_processed: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this document has been processed for knowledge graph extraction",
|
||||
)
|
||||
|
||||
kg_data: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Knowledge graph data extracted from this document",
|
||||
)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
@@ -604,6 +621,304 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
__tablename__ = "kg_entity_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
grounding: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
ge_determine_instructions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True, default=None
|
||||
)
|
||||
|
||||
ge_grounding_signature: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=False, default=None
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this entity type",
|
||||
)
|
||||
|
||||
classification_requirements: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Pre-extraction classification requirements and instructions",
|
||||
)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
extraction_sources: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
comment="Sources and methods used to extract this entity",
|
||||
)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipType(Base):
|
||||
__tablename__ = "kg_relationship_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type",
|
||||
)
|
||||
|
||||
|
||||
class KGEntity(Base):
|
||||
__tablename__ = "kg_entity"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationship(Base):
    __tablename__ = "kg_relationship"

    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, index=True
    )

    # Source and target nodes (foreign keys to Entity table)
    source_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )

    target_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )

    # Relationship type
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)

    # Add new relationship type reference
    relationship_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_relationship_type.id_name"),
        nullable=False,
        index=True,
    )

    # Add the SQLAlchemy relationship property
    relationship_type: Mapped["KGRelationshipType"] = relationship(
        "KGRelationshipType", backref="relationship"
    )

    cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to Entity table
    source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
    target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])

    __table_args__ = (
        # Index for querying relationships by type
        Index("ix_kg_relationship_type", type),
        # Composite index for source/target queries
        Index("ix_kg_relationship_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )

class KGTerm(Base):
    __tablename__ = "kg_term"

    # Make id_term the primary key
    id_term: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, nullable=False, index=True
    )

    # List of entity types this term applies to
    entity_types: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    __table_args__ = (
        # Index for searching terms with specific entity types
        Index("ix_search_term_entities", entity_types),
        # Index for term lookups
        Index("ix_search_term_term", id_term),
    )

class ChunkStats(Base):
    __tablename__ = "chunk_stats"
    # NOTE: if more sensitive data is added here for display, make sure to add user/group permission
@@ -691,6 +1006,14 @@ class Connector(Base):
    indexing_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime, nullable=True
    )

    kg_extraction_enabled: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this connector should extract knowledge graph entities",
    )

    refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    time_created: Mapped[datetime.datetime] = mapped_column(
@@ -1134,6 +1457,12 @@ class DocumentByConnectorCredentialPair(Base):
    # the actual indexing is complete
    has_been_indexed: Mapped[bool] = mapped_column(Boolean)

    has_been_kg_processed: Mapped[bool | None] = mapped_column(
        Boolean,
        nullable=True,
        comment="Whether this document has been processed for knowledge graph extraction",
    )

    connector: Mapped[Connector] = relationship(
        "Connector", back_populates="documents_by_connector", passive_deletes=True
    )

backend/onyx/db/relationships.py (new file, 298 lines)
@@ -0,0 +1,298 @@
from typing import List

from sqlalchemy import or_
from sqlalchemy.orm import Session

from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipType
from onyx.kg.utils.formatting_utils import format_entity
from onyx.kg.utils.formatting_utils import format_relationship
from onyx.kg.utils.formatting_utils import generate_relationship_type


def add_relationship(
    db_session: Session,
    relationship_id_name: str,
    cluster_count: int | None = None,
) -> "KGRelationship":
    """
    Add a relationship between two entities to the database.

    Args:
        db_session: SQLAlchemy database session
        relationship_id_name: Relationship identifier in the format
            "source_entity__relationship__target_entity"
        cluster_count: Optional count of similar relationships clustered together

    Returns:
        The created KGRelationship object

    Raises:
        sqlalchemy.exc.IntegrityError: If the relationship already exists or entities don't exist
    """
    # Parse the relationship id_name into its components

    (
        source_entity_id_name,
        relationship_string,
        target_entity_id_name,
    ) = relationship_id_name.split("__")

    source_entity_id_name = format_entity(source_entity_id_name)
    target_entity_id_name = format_entity(target_entity_id_name)
    relationship_id_name = format_relationship(relationship_id_name)
    relationship_type = generate_relationship_type(relationship_id_name)

    # Create new relationship
    relationship = KGRelationship(
        id_name=relationship_id_name,
        source_node=source_entity_id_name,
        target_node=target_entity_id_name,
        type=relationship_string.lower(),
        relationship_type_id_name=relationship_type,
        cluster_count=cluster_count,
    )

    db_session.add(relationship)
    db_session.flush()  # Flush to get any DB errors early

    return relationship

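
Illustrative sketch (not part of the diff): calling add_relationship with the "source__relationship__target" id format described above. The entity ids and relationship verb are invented; both entities and the corresponding row in kg_relationship_type must already exist, otherwise the flush raises an IntegrityError.

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import add_relationship

with get_session_with_current_tenant() as db_session:
    rel = add_relationship(
        db_session,
        "ACCOUNT:acme__uses__VENDOR:onyx",
        cluster_count=1,
    )
    db_session.commit()
    print(rel.id_name, rel.type)
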
def add_or_increment_relationship(
|
||||
db_session: Session,
|
||||
relationship_id_name: str,
|
||||
) -> "KGRelationship":
|
||||
"""
|
||||
Add a relationship between two entities to the database if it doesn't exist,
|
||||
or increment its cluster_count by 1 if it already exists.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
|
||||
|
||||
Returns:
|
||||
The created or updated KGRelationship object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
|
||||
"""
|
||||
# Format the relationship_id_name
|
||||
relationship_id_name = format_relationship(relationship_id_name)
|
||||
|
||||
# Check if the relationship already exists
|
||||
existing_relationship = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name == relationship_id_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_relationship:
|
||||
# If it exists, increment the cluster_count
|
||||
existing_relationship.cluster_count = (
|
||||
existing_relationship.cluster_count or 0
|
||||
) + 1
|
||||
db_session.flush()
|
||||
return existing_relationship
|
||||
else:
|
||||
# If it doesn't exist, add it with cluster_count=1
|
||||
return add_relationship(db_session, relationship_id_name, cluster_count=1)
|
||||
|
||||
|
||||
def add_relationship_type(
|
||||
db_session: Session,
|
||||
source_entity_type: str,
|
||||
relationship_type: str,
|
||||
target_entity_type: str,
|
||||
definition: bool = False,
|
||||
extraction_count: int = 0,
|
||||
) -> "KGRelationshipType":
|
||||
"""
|
||||
Add a new relationship type to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
source_entity_type: Type of the source entity
|
||||
relationship_type: Type of relationship
|
||||
target_entity_type: Type of the target entity
|
||||
definition: Whether this relationship type represents a definition (default False)
|
||||
|
||||
Returns:
|
||||
The created KGRelationshipType object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If the relationship type already exists
|
||||
"""
|
||||
|
||||
id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"
|
||||
# Create new relationship type
|
||||
rel_type = KGRelationshipType(
|
||||
id_name=id_name,
|
||||
name=relationship_type,
|
||||
source_entity_type_id_name=source_entity_type.upper(),
|
||||
target_entity_type_id_name=target_entity_type.upper(),
|
||||
definition=definition,
|
||||
cluster_count=extraction_count,
|
||||
type=relationship_type, # Using the relationship_type as the type
|
||||
active=True, # Setting as active by default
|
||||
)
|
||||
|
||||
db_session.add(rel_type)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return rel_type
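
Illustrative sketch (not part of the diff): registering a relationship type; the entity type and verb names are invented. The id_name is derived as "<SOURCE_TYPE>__<relationship>__<TARGET_TYPE>", so the call below would produce "ACCOUNT__uses__VENDOR".

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import add_relationship_type

with get_session_with_current_tenant() as db_session:
    rel_type = add_relationship_type(
        db_session,
        source_entity_type="account",
        relationship_type="uses",
        target_entity_type="vendor",
    )
    db_session.commit()
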
|
||||
|
||||
|
||||
def get_all_relationship_types(db_session: Session) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Retrieve all relationship types from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects
|
||||
"""
|
||||
return db_session.query(KGRelationshipType).all()
|
||||
|
||||
|
||||
def get_all_relationships(db_session: Session) -> list["KGRelationship"]:
|
||||
"""
|
||||
Retrieve all relationships from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationship objects
|
||||
"""
|
||||
return db_session.query(KGRelationship).all()
|
||||
|
||||
|
||||
def delete_relationships_by_id_names(db_session: Session, id_names: list[str]) -> int:
|
||||
"""
|
||||
Delete relationships from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationships deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def delete_relationship_types_by_id_names(
|
||||
db_session: Session, id_names: list[str]
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationship types from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship type id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationship types deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipType)
|
||||
.filter(KGRelationshipType.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_relationships_for_entity_type_pairs(
|
||||
db_session: Session, entity_type_pairs: list[tuple[str, str]]
|
||||
) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Get relationship types from the database based on a list of entity type pairs.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects where source and target types match the provided pairs
|
||||
"""
|
||||
|
||||
conditions = [
|
||||
(
|
||||
(KGRelationshipType.source_entity_type_id_name == source_type)
|
||||
& (KGRelationshipType.target_entity_type_id_name == target_type)
|
||||
)
|
||||
for source_type, target_type in entity_type_pairs
|
||||
]
|
||||
|
||||
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
|
||||
|
||||
|
||||
def get_allowed_relationship_type_pairs(
|
||||
db_session: Session, entities: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the allowed relationship pairs for the given entities.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entities: List of entity type ID names to filter by
|
||||
|
||||
Returns:
|
||||
List of id_names from KGRelationshipType where both source and target entity types
|
||||
are in the provided entities list
|
||||
"""
|
||||
entity_types = list(set([entity.split(":")[0] for entity in entities]))
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
|
||||
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
]
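
Illustrative sketch (not part of the diff): the entity arguments are reduced to their types before matching, so a call like the one below (invented values) returns the id_names of all relationship types whose source and target types are both among "ACCOUNT" and "EMPLOYEE".

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs

with get_session_with_current_tenant() as db_session:
    allowed = get_allowed_relationship_type_pairs(
        db_session, entities=["ACCOUNT:acme", "EMPLOYEE:jane"]
    )
    print(allowed)  # e.g. ["ACCOUNT__uses__ACCOUNT", "EMPLOYEE__works_with__ACCOUNT"]
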
|
||||
|
||||
|
||||
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
|
||||
"""Get all relationship ID names where the given entity is either the source or target node.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_id: ID of the entity to find relationships for
|
||||
|
||||
Returns:
|
||||
List of relationship ID names where the entity is either source or target
|
||||
"""
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationship.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationship.source_node == entity_id,
|
||||
KGRelationship.target_node == entity_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
@@ -98,7 +98,7 @@ class VespaDocumentFields:
    understandable like this for now.
    """

    # all other fields except these 4 will always be left alone by the update request
    # all other fields except these 4 and knowledge graph will always be left alone by the update request
    access: DocumentAccess | None = None
    document_sets: set[str] | None = None
    boost: float | None = None

@@ -85,6 +85,25 @@ schema DANSWER_CHUNK_NAME {
        indexing: attribute
    }

    # Separate array fields for knowledge graph data
    field kg_entities type weightedset<string> {
        rank: filter
        indexing: summary | attribute
        attribute: fast-search
    }

    field kg_relationships type weightedset<string> {
        indexing: summary | attribute
        rank: filter
        attribute: fast-search
    }

    field kg_terms type weightedset<string> {
        indexing: summary | attribute
        rank: filter
        attribute: fast-search
    }

    # Needs to have a separate Attribute list for efficient filtering
    field metadata_list type array<string> {
        indexing: summary | attribute

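
Illustrative sketch (not part of the diff): because the kg_* fields are fast-search attributes, they can be used directly in query filters; the clause shape below mirrors the `contains` filters built by _build_kg_filter further down in this change set, with invented values.

yql_filter = (
    '(kg_entities contains "ACCOUNT:acme") and '
    '(kg_relationships contains "ACCOUNT:acme__uses__VENDOR:onyx")'
)
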
@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
# build the list of fields to retrieve
|
||||
field_set_list = (
|
||||
None
|
||||
if not field_names
|
||||
else [f"{index_name}:{field_name}" for field_name in field_names]
|
||||
[] if not field_names else [f"{field_name}" for field_name in field_names]
|
||||
)
|
||||
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
|
||||
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
|
||||
if (
|
||||
field_set_list
|
||||
and filters.access_control_list
|
||||
and acl_fieldset_entry not in field_set_list
|
||||
):
|
||||
field_set_list.append(acl_fieldset_entry)
|
||||
field_set = ",".join(field_set_list) if field_set_list else None
|
||||
if field_set_list:
|
||||
field_set = f"{index_name}:" + ",".join(field_set_list)
|
||||
else:
|
||||
field_set = None
|
||||
|
||||
# build filters
|
||||
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
|
||||
|
||||
@@ -17,6 +17,7 @@ from uuid import UUID
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
@@ -29,6 +30,7 @@ from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
@@ -100,6 +102,37 @@ class _VespaUpdateRequest:
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGVespaChunkUpdateRequest(BaseModel):
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
url: str
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGUChunkUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: set[str] | None = None
|
||||
relationships: set[str] | None = None
|
||||
terms: set[str] | None = None
|
||||
|
||||
|
||||
class KGUDocumentUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
entities: set[str]
|
||||
relationships: set[str]
|
||||
terms: set[str]
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -504,6 +537,51 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
@classmethod
|
||||
def _apply_kg_chunk_updates_batched(
|
||||
cls,
|
||||
updates: list[KGVespaChunkUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
|
||||
def _kg_update_chunk(
|
||||
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
|
||||
) -> httpx.Response:
|
||||
# logger.debug(
|
||||
# f"Updating KG with request to {update.url} with body {update.update_request}"
|
||||
# )
|
||||
return http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
httpx_client as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
executor.submit(
|
||||
_kg_update_chunk,
|
||||
update,
|
||||
http_client,
|
||||
): update.document_id
|
||||
for update in update_batch
|
||||
}
|
||||
for future in concurrent.futures.as_completed(future_to_document_id):
|
||||
res = future.result()
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
@@ -587,6 +665,89 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def kg_chunk_updates(
|
||||
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
|
||||
) -> None:
|
||||
def _get_general_entity(specific_entity: str) -> str:
|
||||
entity_type, entity_name = specific_entity.split(":")
|
||||
if entity_type != "*":
|
||||
return f"{entity_type}:*"
|
||||
else:
|
||||
return specific_entity
|
||||
|
||||
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
|
||||
logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
|
||||
|
||||
update_start = time.monotonic()
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
|
||||
for kg_update_request in kg_update_requests:
|
||||
kg_update_dict: dict[str, dict] = {"fields": {}}
|
||||
|
||||
implied_entities = set()
|
||||
if kg_update_request.relationships is not None:
|
||||
for kg_relationship in kg_update_request.relationships:
|
||||
kg_relationship_split = kg_relationship.split("__")
|
||||
if len(kg_relationship_split) == 3:
|
||||
implied_entities.add(kg_relationship_split[0])
|
||||
implied_entities.add(kg_relationship_split[2])
|
||||
# Keep this for now in case we want to also add the general entities
|
||||
# implied_entities.add(_get_general_entity(kg_relationship_split[0]))
|
||||
# implied_entities.add(_get_general_entity(kg_relationship_split[2]))
|
||||
|
||||
kg_update_dict["fields"]["kg_relationships"] = {
|
||||
"assign": {
|
||||
kg_relationship: 1
|
||||
for kg_relationship in kg_update_request.relationships
|
||||
}
|
||||
}
|
||||
|
||||
if kg_update_request.entities is not None or implied_entities:
|
||||
if kg_update_request.entities is None:
|
||||
kg_entities = implied_entities
|
||||
else:
|
||||
kg_entities = set(kg_update_request.entities)
|
||||
kg_entities.update(implied_entities)
|
||||
|
||||
kg_update_dict["fields"]["kg_entities"] = {
|
||||
"assign": {kg_entity: 1 for kg_entity in kg_entities}
|
||||
}
|
||||
|
||||
if kg_update_request.terms is not None:
|
||||
kg_update_dict["fields"]["kg_terms"] = {
|
||||
"assign": {kg_term: 1 for kg_term in kg_update_request.terms}
|
||||
}
|
||||
|
||||
if not kg_update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
continue
|
||||
|
||||
doc_chunk_id = get_uuid_from_chunk_info(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
tenant_id=tenant_id,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
processed_updates_requests.append(
|
||||
KGVespaChunkUpdateRequest(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=kg_update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_kg_chunk_updates_batched(
|
||||
processed_updates_requests, httpx_client
|
||||
)
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
)
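
Illustrative sketch (not part of the diff): one KGUChunkUpdateRequest (invented values) and the shape of the Vespa partial update that kg_chunk_updates builds for it. Note that the two endpoints of each relationship are folded into kg_entities as implied entities.

request = KGUChunkUpdateRequest(
    document_id="doc-123",
    chunk_id=0,
    core_entity="FIREFLIES:doc-123",
    entities={"ACCOUNT:acme"},
    relationships={"ACCOUNT:acme__uses__VENDOR:onyx"},
    terms={"pricing"},
)
# Resulting update_request body for this chunk:
# {
#     "fields": {
#         "kg_relationships": {"assign": {"ACCOUNT:acme__uses__VENDOR:onyx": 1}},
#         "kg_entities": {"assign": {"ACCOUNT:acme": 1, "VENDOR:onyx": 1}},
#         "kg_terms": {"assign": {"pricing": 1}},
#     }
# }
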
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
|
||||
backend/onyx/document_index/vespa/kg_interactions.py (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from backend.onyx.chat.process_message import get_inference_chunks
|
||||
# from backend.onyx.document_index.vespa.index import VespaIndex
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class KGChunkInfo(BaseModel):
|
||||
kg_relationships: dict[str, int]
|
||||
kg_entities: dict[str, int]
|
||||
kg_terms: dict[str, int]
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def get_document_kg_info(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
filters: IndexFilters | None = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieve the kg_info attribute from a Vespa document by its document_id.
|
||||
|
||||
Args:
|
||||
document_id: The unique identifier of the document.
|
||||
index_name: The name of the Vespa index to query.
|
||||
filters: Optional access control filters to apply.
|
||||
|
||||
Returns:
|
||||
The kg_info dictionary if found, None otherwise.
|
||||
"""
|
||||
# Use the existing visit API infrastructure
|
||||
kg_doc_info: dict[int, KGChunkInfo] = {}
|
||||
|
||||
document_chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=filters or IndexFilters(access_control_list=None),
|
||||
field_names=["kg_relationships", "kg_entities", "kg_terms"],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
for chunk_id, document_chunk in enumerate(document_chunks):
|
||||
kg_chunk_info = KGChunkInfo(
|
||||
kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
|
||||
kg_entities=document_chunk["fields"].get("kg_entities", {}),
|
||||
kg_terms=document_chunk["fields"].get("kg_terms", {}),
|
||||
)
|
||||
|
||||
kg_doc_info[chunk_id] = kg_chunk_info # TODO: check the chunk id is correct!
|
||||
|
||||
return kg_doc_info
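
Illustrative sketch (not part of the diff): reading the per-chunk KG info back out of Vespa; the document id and index name are invented.

kg_info = get_document_kg_info(document_id="doc-123", index_name="danswer_chunk")
if kg_info:
    for chunk_id, chunk_info in kg_info.items():
        print(chunk_id, chunk_info.kg_entities, chunk_info.kg_relationships)
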
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def update_kg_chunks_vespa_info(
|
||||
kg_update_requests: list[KGUChunkUpdateRequest],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
""" """
|
||||
# Use the existing visit API infrastructure
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=False,
|
||||
multitenant=False,
|
||||
httpx_client=None,
|
||||
)
|
||||
|
||||
vespa_index.kg_chunk_updates(
|
||||
kg_update_requests=kg_update_requests, tenant_id=tenant_id
|
||||
)
|
||||
@@ -5,7 +5,6 @@ from datetime import timezone
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
@@ -67,6 +66,29 @@ def build_vespa_filters(
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
kg_relationships: list[str] | None,
|
||||
kg_terms: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_entities and not kg_relationships and not kg_terms:
|
||||
return ""
|
||||
|
||||
filter_parts = []
|
||||
|
||||
# Process each filter type using the same pattern
|
||||
for filter_type, values in [
|
||||
("kg_entities", kg_entities),
|
||||
("kg_relationships", kg_relationships),
|
||||
("kg_terms", kg_terms),
|
||||
]:
|
||||
if values:
|
||||
filter_parts.append(
|
||||
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
|
||||
)
|
||||
|
||||
return f"({' and '.join(filter_parts)}) and "
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -106,6 +128,13 @@ def build_vespa_filters(
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# KG filter
|
||||
filter_str += _build_kg_filter(
|
||||
kg_entities=filters.kg_entities,
|
||||
kg_relationships=filters.kg_relationships,
|
||||
kg_terms=filters.kg_terms,
|
||||
)
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
backend/onyx/kg/clustering/clustering.py (new file, 1022 lines; diff suppressed because it is too large)
backend/onyx/kg/clustering/normalizations.py (new file, 229 lines)
@@ -0,0 +1,229 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from thefuzz import process # type: ignore
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_entities_for_types
|
||||
from onyx.db.relationships import get_relationships_for_entity_type_pairs
|
||||
from onyx.kg.models import NormalizedEntities
|
||||
from onyx.kg.models import NormalizedRelationships
|
||||
from onyx.kg.models import NormalizedTerms
|
||||
from onyx.kg.utils.embeddings import encode_string_batch
|
||||
|
||||
|
||||
def _get_existing_normalized_entities(raw_entities: List[str]) -> List[str]:
|
||||
"""
|
||||
Get existing normalized entities from the database.
|
||||
"""
|
||||
|
||||
entity_types = list(set([entity.split(":")[0] for entity in raw_entities]))
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entities = get_entities_for_types(db_session, entity_types)
|
||||
|
||||
return [entity.id_name for entity in entities]
|
||||
|
||||
|
||||
def _get_existing_normalized_relationships(
|
||||
raw_relationships: List[str],
|
||||
) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Get existing normalized relationships from the database.
|
||||
"""
|
||||
|
||||
relationship_type_map: Dict[str, Dict[str, List[str]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
relationship_pairs = list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
relationship.split("__")[0].split(":")[0],
|
||||
relationship.split("__")[2].split(":")[0],
|
||||
)
|
||||
for relationship in raw_relationships
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationships = get_relationships_for_entity_type_pairs(
|
||||
db_session, relationship_pairs
|
||||
)
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_type_map[relationship.source_entity_type_id_name][
|
||||
relationship.target_entity_type_id_name
|
||||
].append(relationship.id_name)
|
||||
|
||||
return relationship_type_map
|
||||
|
||||
|
||||
def normalize_entities(raw_entities: List[str]) -> NormalizedEntities:
|
||||
"""
|
||||
Match each entity against a list of normalized entities using fuzzy matching.
|
||||
Returns the best matching normalized entity for each input entity.
|
||||
|
||||
Args:
|
||||
entities: List of entity strings to normalize
|
||||
|
||||
Returns:
|
||||
List of normalized entity strings
|
||||
"""
|
||||
# Assume this is your predefined list of normalized entities
|
||||
norm_entities = _get_existing_normalized_entities(raw_entities)
|
||||
|
||||
normalized_results: List[str] = []
|
||||
normalized_map: Dict[str, str | None] = {}
|
||||
threshold = 80 # Adjust threshold as needed
|
||||
|
||||
for entity in raw_entities:
|
||||
# Find the best match and its score from norm_entities
|
||||
best_match, score = process.extractOne(entity, norm_entities)
|
||||
|
||||
if score >= threshold:
|
||||
normalized_results.append(best_match)
|
||||
normalized_map[entity] = best_match
|
||||
else:
|
||||
# If no good match found, keep original
|
||||
normalized_map[entity] = None
|
||||
|
||||
return NormalizedEntities(
|
||||
entities=normalized_results, entity_normalization_map=normalized_map
|
||||
)
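
Illustrative sketch (not part of the diff): normalizing two raw entities against whatever ACCOUNT entities already exist in the database; the values are invented. Entities whose best fuzzy match scores at least 80 are replaced by the stored id_name, the rest map to None.

normalized = normalize_entities(["ACCOUNT:Acme Corporation", "ACCOUNT:unknown-co"])
print(normalized.entities)                  # e.g. ["ACCOUNT:acme"]
print(normalized.entity_normalization_map)  # e.g. {"ACCOUNT:Acme Corporation": "ACCOUNT:acme", "ACCOUNT:unknown-co": None}
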
|
||||
|
||||
|
||||
def normalize_relationships(
|
||||
raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
|
||||
) -> NormalizedRelationships:
|
||||
"""
|
||||
Normalize relationships using entity mappings and relationship string matching.
|
||||
|
||||
Args:
|
||||
relationships: List of relationships in format "source__relation__target"
|
||||
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
|
||||
|
||||
Returns:
|
||||
NormalizedRelationships containing normalized relationships and mapping
|
||||
"""
|
||||
# Placeholder for normalized relationship structure
|
||||
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
|
||||
|
||||
normalized_rels: List[str] = []
|
||||
normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
for raw_rel in raw_relationships:
|
||||
# 1. Split and normalize entities
|
||||
try:
|
||||
source, rel_string, target = raw_rel.split("__")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid relationship format: {raw_rel}")
|
||||
|
||||
# Check if entities are in normalization map and not None
|
||||
norm_source = entity_normalization_map.get(source)
|
||||
norm_target = entity_normalization_map.get(target)
|
||||
|
||||
if norm_source is None or norm_target is None:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 2. Find candidate normalized relationships
|
||||
candidate_rels = []
|
||||
norm_source_type = norm_source.split(":")[0]
|
||||
norm_target_type = norm_target.split(":")[0]
|
||||
if (
|
||||
norm_source_type in nor_relationships
|
||||
and norm_target_type in nor_relationships[norm_source_type]
|
||||
):
|
||||
candidate_rels = [
|
||||
rel.split("__")[1]
|
||||
for rel in nor_relationships[norm_source_type][norm_target_type]
|
||||
]
|
||||
|
||||
if not candidate_rels:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 3. Encode and find best match
|
||||
strings_to_encode = [rel_string] + candidate_rels
|
||||
vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# Get raw relation vector and candidate vectors
|
||||
raw_vector = vectors[0]
|
||||
candidate_vectors = vectors[1:]
|
||||
|
||||
# Calculate dot products
|
||||
dot_products = np.dot(candidate_vectors, raw_vector)
|
||||
best_match_idx = np.argmax(dot_products)
|
||||
|
||||
# Create normalized relationship
|
||||
norm_rel = f"{norm_source}__{candidate_rels[best_match_idx]}__{norm_target}"
|
||||
normalized_rels.append(norm_rel)
|
||||
normalization_map[raw_rel] = norm_rel
|
||||
|
||||
return NormalizedRelationships(
|
||||
relationships=normalized_rels, relationship_normalization_map=normalization_map
|
||||
)
|
||||
|
||||
|
||||
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
|
||||
"""
|
||||
Normalize terms using semantic similarity matching.
|
||||
|
||||
Args:
|
||||
terms: List of terms to normalize
|
||||
|
||||
Returns:
|
||||
NormalizedTerms containing normalized terms and mapping
|
||||
"""
|
||||
# # Placeholder for normalized terms - this would typically come from a predefined list
|
||||
# normalized_term_list = [
|
||||
# "algorithm",
|
||||
# "database",
|
||||
# "software",
|
||||
# "programming",
|
||||
# # ... other normalized terms ...
|
||||
# ]
|
||||
|
||||
# normalized_terms: List[str] = []
|
||||
# normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
# if not raw_terms:
|
||||
# return NormalizedTerms(terms=[], term_normalization_map={})
|
||||
|
||||
# # Encode all terms at once for efficiency
|
||||
# strings_to_encode = raw_terms + normalized_term_list
|
||||
# vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# # Split vectors into query terms and candidate terms
|
||||
# query_vectors = vectors[:len(raw_terms)]
|
||||
# candidate_vectors = vectors[len(raw_terms):]
|
||||
|
||||
# # Calculate similarity for each term
|
||||
# for i, term in enumerate(raw_terms):
|
||||
# # Calculate dot products with all candidates
|
||||
# similarities = np.dot(candidate_vectors, query_vectors[i])
|
||||
# best_match_idx = np.argmax(similarities)
|
||||
# best_match_score = similarities[best_match_idx]
|
||||
|
||||
# # Use a threshold to determine if the match is good enough
|
||||
# if best_match_score > 0.7: # Adjust threshold as needed
|
||||
# normalized_term = normalized_term_list[best_match_idx]
|
||||
# normalized_terms.append(normalized_term)
|
||||
# normalization_map[term] = normalized_term
|
||||
# else:
|
||||
# # If no good match found, keep original
|
||||
# normalization_map[term] = None
|
||||
|
||||
# return NormalizedTerms(
|
||||
# terms=normalized_terms,
|
||||
# term_normalization_map=normalization_map
|
||||
# )
|
||||
|
||||
return NormalizedTerms(
|
||||
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
|
||||
)
|
||||
backend/onyx/kg/context_preparations_extraction/fireflies.py (new file, 191 lines)
@@ -0,0 +1,191 @@
|
||||
from onyx.configs.kg_configs import KG_IGNORE_EMAIL_DOMAINS
|
||||
from onyx.configs.kg_configs import KG_OWN_COMPANY
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_kg_entity_by_document
|
||||
from onyx.kg.context_preparations_extraction.models import ContextPreparation
|
||||
from onyx.kg.context_preparations_extraction.models import (
|
||||
KGDocumentClassificationPrompt,
|
||||
)
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import kg_email_processing
|
||||
from onyx.prompts.kg_prompts import FIREFLIES_CHUNK_PREPROCESSING_PROMPT
|
||||
from onyx.prompts.kg_prompts import FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT
|
||||
|
||||
|
||||
def prepare_llm_content_fireflies(chunk: KGChunkFormat) -> ContextPreparation:
|
||||
"""
|
||||
Fireflies - prepare the content for the LLM.
|
||||
"""
|
||||
|
||||
document_id = chunk.document_id
|
||||
primary_owners = chunk.primary_owners
|
||||
secondary_owners = chunk.secondary_owners
|
||||
content = chunk.content
|
||||
chunk.title.capitalize()
|
||||
|
||||
implied_entities = set()
|
||||
implied_relationships = set()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
core_document = get_kg_entity_by_document(db_session, document_id)
|
||||
|
||||
if core_document:
|
||||
core_document_id_name = core_document.id_name
|
||||
else:
|
||||
core_document_id_name = f"FIREFLIES:{document_id}"
|
||||
|
||||
# Do we need this here?
|
||||
implied_entities.add(f"VENDOR:{KG_OWN_COMPANY}")
|
||||
implied_entities.add(f"{core_document_id_name}")
|
||||
implied_entities.add("FIREFLIES:*")
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{KG_OWN_COMPANY}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
company_participant_emails = set()
|
||||
account_participant_emails = set()
|
||||
|
||||
for owner in primary_owners + secondary_owners:
|
||||
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
kg_owner = kg_email_processing(owner)
|
||||
if any(
|
||||
domain.lower() in kg_owner.company.lower()
|
||||
for domain in KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
continue
|
||||
|
||||
if kg_owner.employee:
|
||||
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
if kg_owner.name not in implied_entities:
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_entities.add(f"EMPLOYEE:{kg_owner.name}")
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:*__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:*__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
if kg_owner.company not in implied_entities:
|
||||
implied_entities.add(f"VENDOR:{kg_owner.company}")
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
else:
|
||||
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
if kg_owner.company not in implied_entities:
|
||||
implied_entities.add(f"ACCOUNT:{kg_owner.company}")
|
||||
implied_entities.add("ACCOUNT:*")
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:*__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:*__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
participant_string = "\n - " + "\n - ".join(company_participant_emails)
|
||||
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
|
||||
|
||||
llm_context = FIREFLIES_CHUNK_PREPROCESSING_PROMPT.format(
|
||||
participant_string=participant_string,
|
||||
account_participant_string=account_participant_string,
|
||||
content=content,
|
||||
)
|
||||
|
||||
return ContextPreparation(
|
||||
llm_context=llm_context,
|
||||
core_entity=core_document_id_name,
|
||||
implied_entities=list(implied_entities),
|
||||
implied_relationships=list(implied_relationships),
|
||||
implied_terms=[],
|
||||
)
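
# Illustrative example (invented values): for a call document whose core entity resolves to
# "FIREFLIES:doc-1", with one internal participant "jane" and one external participant at
# "acme", the returned ContextPreparation roughly carries
#   core_entity: "FIREFLIES:doc-1"
#   implied_entities: "VENDOR:<own company>", "FIREFLIES:doc-1", "FIREFLIES:*",
#                     "EMPLOYEE:jane", "ACCOUNT:acme", "ACCOUNT:*"
#   implied_relationships: "EMPLOYEE:jane__relates_neutrally_to__FIREFLIES:doc-1",
#                          "ACCOUNT:acme__relates_neutrally_to__FIREFLIES:doc-1", ...
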
|
||||
|
||||
|
||||
def prepare_llm_document_content_fireflies(
|
||||
document_classification_content: KGClassificationContent,
|
||||
category_list: str,
|
||||
category_definition_string: str,
|
||||
) -> KGDocumentClassificationPrompt:
|
||||
"""
|
||||
Fireflies - prepare prompt for the LLM classification.
|
||||
"""
|
||||
|
||||
prompt = FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT.format(
|
||||
beginning_of_call_content=document_classification_content.classification_content,
|
||||
category_list=category_list,
|
||||
category_options=category_definition_string,
|
||||
)
|
||||
|
||||
return KGDocumentClassificationPrompt(
|
||||
llm_prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def get_classification_content_from_fireflies_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Creates a KGClassificationContent object from a list of Fireflies chunks.
|
||||
"""
|
||||
|
||||
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
primary_owners = first_num_classification_chunks[0]["fields"]["primary_owners"]
|
||||
secondary_owners = first_num_classification_chunks[0]["fields"]["secondary_owners"]
|
||||
|
||||
company_participant_emails = set()
|
||||
account_participant_emails = set()
|
||||
|
||||
for owner in primary_owners + secondary_owners:
|
||||
kg_owner = kg_email_processing(owner)
|
||||
if any(
|
||||
domain.lower() in kg_owner.company.lower()
|
||||
for domain in KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
continue
|
||||
|
||||
if kg_owner.employee:
|
||||
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
else:
|
||||
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
|
||||
participant_string = "\n - " + "\n - ".join(company_participant_emails)
|
||||
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
|
||||
|
||||
title_string = first_num_classification_chunks[0]["fields"]["title"]
|
||||
content_string = "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
|
||||
classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
|
||||
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
|
||||
|
||||
return classification_content
|
||||
backend/onyx/kg/context_preparations_extraction/models.py (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ContextPreparation(BaseModel):
|
||||
"""
|
||||
Context preparation format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_context: str
|
||||
core_entity: str
|
||||
implied_entities: list[str]
|
||||
implied_relationships: list[str]
|
||||
implied_terms: list[str]
|
||||
|
||||
|
||||
class KGDocumentClassificationPrompt(BaseModel):
|
||||
"""
|
||||
Document classification prompt format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_prompt: str | None
|
||||
backend/onyx/kg/extractions/extraction_processing.py (new file, 947 lines)
@@ -0,0 +1,947 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from onyx.db.connector import get_unprocessed_connector_ids
|
||||
from onyx.db.document import get_document_updated_at
|
||||
from onyx.db.document import get_unprocessed_kg_documents_for_connector
|
||||
from onyx.db.document import update_document_kg_info
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import add_entity
|
||||
from onyx.db.entities import get_entity_types
|
||||
from onyx.db.relationships import add_or_increment_relationship
|
||||
from onyx.db.relationships import add_relationship
|
||||
from onyx.db.relationships import add_relationship_type
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import KGUDocumentUpdateRequest
|
||||
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
|
||||
from onyx.kg.models import ConnectorExtractionStats
|
||||
from onyx.kg.models import KGAggregatedExtractions
|
||||
from onyx.kg.models import KGBatchExtractionStats
|
||||
from onyx.kg.models import KGChunkExtraction
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGChunkId
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.models import KGClassificationDecisions
|
||||
from onyx.kg.models import KGClassificationInstructionStrings
|
||||
from onyx.kg.utils.chunk_preprocessing import prepare_llm_content
|
||||
from onyx.kg.utils.chunk_preprocessing import prepare_llm_document_content
|
||||
from onyx.kg.utils.formatting_utils import aggregate_kg_extractions
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import generalize_relationships
|
||||
from onyx.kg.vespa.vespa_interactions import get_document_chunks_for_kg_processing
|
||||
from onyx.kg.vespa.vespa_interactions import (
|
||||
get_document_classification_content_for_kg_processing,
|
||||
)
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.kg_prompts import MASTER_EXTRACTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_classification_instructions() -> Dict[str, KGClassificationInstructionStrings]:
|
||||
"""
|
||||
Prepare the classification instructions for the given source.
|
||||
"""
|
||||
|
||||
classification_instructions_dict: Dict[str, KGClassificationInstructionStrings] = {}
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_types = get_entity_types(db_session, active=None)
|
||||
|
||||
for entity_type in entity_types:
|
||||
grounded_source_name = entity_type.grounded_source_name
|
||||
if grounded_source_name is None:
|
||||
continue
|
||||
classification_class_definitions = entity_type.classification_requirements
|
||||
|
||||
classification_options = ", ".join(classification_class_definitions.keys())
|
||||
|
||||
classification_instructions_dict[
|
||||
grounded_source_name
|
||||
] = KGClassificationInstructionStrings(
|
||||
classification_options=classification_options,
|
||||
classification_class_definitions=classification_class_definitions,
|
||||
)
|
||||
|
||||
return classification_instructions_dict
|
||||
|
||||
|
||||
def get_entity_types_str(active: bool | None = None) -> str:
|
||||
"""
|
||||
Get the entity types from the KGChunkExtraction model.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_entity_types = get_entity_types(db_session, active)
|
||||
|
||||
entity_types_list = []
|
||||
for entity_type in active_entity_types:
|
||||
if entity_type.description:
|
||||
entity_description = "\n - Description: " + entity_type.description
|
||||
else:
|
||||
entity_description = ""
|
||||
if entity_type.ge_determine_instructions:
|
||||
allowed_options = "\n - Allowed Options: " + ", ".join(
|
||||
entity_type.ge_determine_instructions
|
||||
)
|
||||
else:
|
||||
allowed_options = ""
|
||||
entity_types_list.append(
|
||||
entity_type.id_name + entity_description + allowed_options
|
||||
)
|
||||
|
||||
return "\n".join(entity_types_list)
|
||||
|
||||
|
||||
def get_relationship_types_str(active: bool | None = None) -> str:
|
||||
"""
|
||||
Get the relationship types from the database.
|
||||
|
||||
Args:
|
||||
active: Filter by active status (True, False, or None for all)
|
||||
|
||||
Returns:
|
||||
A string with all relationship types formatted as "source_type__relationship_type__target_type"
|
||||
"""
|
||||
from onyx.db.relationships import get_all_relationship_types
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationship_types = get_all_relationship_types(db_session)
|
||||
|
||||
# Filter by active status if specified
|
||||
if active is not None:
|
||||
relationship_types = [
|
||||
rt for rt in relationship_types if rt.active == active
|
||||
]
|
||||
|
||||
relationship_types_list = []
|
||||
for rel_type in relationship_types:
|
||||
# Format as "source_type__relationship_type__target_type"
|
||||
formatted_type = f"{rel_type.source_entity_type_id_name}__{rel_type.type}__{rel_type.target_entity_type_id_name}"
|
||||
relationship_types_list.append(formatted_type)
|
||||
|
||||
return "\n".join(relationship_types_list)
|
||||
|
||||
|
||||
def kg_extraction_initialization(tenant_id: str, num_chunks: int = 1000) -> None:
|
||||
"""
|
||||
This extraction will create a random sample of chunks to process in order to perform
|
||||
clustering and topic modeling.
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg extraction for tenant {tenant_id}")
|
||||
|
||||
|
||||
def kg_extraction(
|
||||
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
|
||||
) -> list[ConnectorExtractionStats]:
|
||||
"""
|
||||
This extraction will try to extract from all chunks that have not been kg-processed yet.
|
||||
|
||||
Approach:
|
||||
- Get all unprocessed connectors
|
||||
- For each connector:
|
||||
- Get all unprocessed documents
|
||||
- Classify each document to select proper ones
|
||||
- For each document:
|
||||
- Get all chunks
|
||||
- For each chunk:
|
||||
- Extract entities, relationships, and terms
|
||||
- make sure for each entity and relationship also the generalized versions are extracted!
|
||||
- Aggregate results as needed
|
||||
- Update Vespa and postgres
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg extraction for tenant {tenant_id}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
connector_ids = get_unprocessed_connector_ids(db_session)
|
||||
|
||||
connector_extraction_stats: list[ConnectorExtractionStats] = []
|
||||
document_kg_updates: Dict[str, KGUDocumentUpdateRequest] = {}
|
||||
|
||||
processing_chunks: list[KGChunkFormat] = []
|
||||
carryover_chunks: list[KGChunkFormat] = []
|
||||
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions] = []
|
||||
|
||||
document_classification_instructions = _get_classification_instructions()
|
||||
|
||||
for connector_id in connector_ids:
|
||||
connector_failed_chunk_extractions: list[KGChunkId] = []
|
||||
connector_succeeded_chunk_extractions: list[KGChunkId] = []
|
||||
connector_aggregated_kg_extractions: KGAggregatedExtractions = (
|
||||
KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
unprocessed_documents = get_unprocessed_kg_documents_for_connector(
|
||||
db_session,
|
||||
connector_id,
|
||||
)
|
||||
|
||||
# TODO: restricted for testing only
|
||||
unprocessed_documents_list = list(unprocessed_documents)[:6]
|
||||
|
||||
document_classification_content_list = (
|
||||
get_document_classification_content_for_kg_processing(
|
||||
[
|
||||
unprocessed_document.id
|
||||
for unprocessed_document in unprocessed_documents_list
|
||||
],
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
classification_outcomes: list[tuple[bool, KGClassificationDecisions]] = []
|
||||
|
||||
for document_classification_content in document_classification_content_list:
|
||||
classification_outcomes.extend(
|
||||
_kg_document_classification(
|
||||
document_classification_content,
|
||||
document_classification_instructions,
|
||||
)
|
||||
)
|
||||
|
||||
documents_to_process = []
|
||||
for document_to_process, document_classification_outcome in zip(
|
||||
unprocessed_documents_list, classification_outcomes
|
||||
):
|
||||
if (
|
||||
document_classification_outcome[0]
|
||||
and document_classification_outcome[1].classification_decision
|
||||
):
|
||||
documents_to_process.append(document_to_process)
|
||||
|
||||
for document_to_process in documents_to_process:
|
||||
formatted_chunk_batches = get_document_chunks_for_kg_processing(
|
||||
document_to_process.id,
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
|
||||
formatted_chunk_batches_list = list(formatted_chunk_batches)
|
||||
|
||||
for formatted_chunk_batch in formatted_chunk_batches_list:
|
||||
processing_chunks.extend(formatted_chunk_batch)
|
||||
|
||||
if len(processing_chunks) >= processing_chunk_batch_size:
|
||||
carryover_chunks.extend(
|
||||
processing_chunks[processing_chunk_batch_size:]
|
||||
)
|
||||
processing_chunks = processing_chunks[:processing_chunk_batch_size]
|
||||
|
||||
chunk_processing_batch_results = _kg_chunk_batch_extraction(
|
||||
processing_chunks, index_name, tenant_id
|
||||
)
|
||||
|
||||
# Consider removing the stats expressions here and rather write to the db(?)
|
||||
connector_failed_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.failed
|
||||
)
|
||||
connector_succeeded_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.succeeded
|
||||
)
|
||||
|
||||
aggregated_batch_extractions = (
|
||||
chunk_processing_batch_results.aggregated_kg_extractions
|
||||
)
|
||||
# Update grounded_entities_document_ids (replace values)
|
||||
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
|
||||
aggregated_batch_extractions.grounded_entities_document_ids
|
||||
)
|
||||
# Add to entity counts instead of replacing
|
||||
for entity, count in aggregated_batch_extractions.entities.items():
|
||||
connector_aggregated_kg_extractions.entities[entity] += count
|
||||
# Add to relationship counts instead of replacing
|
||||
for (
|
||||
relationship,
|
||||
count,
|
||||
) in aggregated_batch_extractions.relationships.items():
|
||||
connector_aggregated_kg_extractions.relationships[
|
||||
relationship
|
||||
] += count
|
||||
# Add to term counts instead of replacing
|
||||
for term, count in aggregated_batch_extractions.terms.items():
|
||||
connector_aggregated_kg_extractions.terms[term] += count
|
||||
|
||||
connector_extraction_stats.append(
|
||||
ConnectorExtractionStats(
|
||||
connector_id=connector_id,
|
||||
num_failed=len(connector_failed_chunk_extractions),
|
||||
num_succeeded=len(connector_succeeded_chunk_extractions),
|
||||
num_processed=len(processing_chunks),
|
||||
)
|
||||
)
|
||||
|
||||
processing_chunks = carryover_chunks.copy()
|
||||
carryover_chunks = []
|
||||
|
||||
# processes remaining chunks
|
||||
chunk_processing_batch_results = _kg_chunk_batch_extraction(
|
||||
processing_chunks, index_name, tenant_id
|
||||
)
|
||||
|
||||
# Consider removing the stats expressions here and rather write to the db(?)
|
||||
connector_failed_chunk_extractions.extend(chunk_processing_batch_results.failed)
|
||||
connector_succeeded_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.succeeded
|
||||
)
|
||||
|
||||
aggregated_batch_extractions = (
|
||||
chunk_processing_batch_results.aggregated_kg_extractions
|
||||
)
|
||||
# Update grounded_entities_document_ids (replace values)
|
||||
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
|
||||
aggregated_batch_extractions.grounded_entities_document_ids
|
||||
)
|
||||
# Add to entity counts instead of replacing
|
||||
for entity, count in aggregated_batch_extractions.entities.items():
|
||||
connector_aggregated_kg_extractions.entities[entity] += count
|
||||
# Add to relationship counts instead of replacing
|
||||
for relationship, count in aggregated_batch_extractions.relationships.items():
|
||||
connector_aggregated_kg_extractions.relationships[relationship] += count
|
||||
# Add to term counts instead of replacing
|
||||
for term, count in aggregated_batch_extractions.terms.items():
|
||||
connector_aggregated_kg_extractions.terms[term] += count
|
||||
|
||||
connector_extraction_stats.append(
|
||||
ConnectorExtractionStats(
|
||||
connector_id=connector_id,
|
||||
num_failed=len(connector_failed_chunk_extractions),
|
||||
num_succeeded=len(connector_succeeded_chunk_extractions),
|
||||
num_processed=len(processing_chunks),
|
||||
)
|
||||
)
|
||||
|
||||
processing_chunks = []
|
||||
carryover_chunks = []
|
||||
|
||||
connector_aggregated_kg_extractions_list.append(
|
||||
connector_aggregated_kg_extractions
|
||||
)
|
||||
|
||||
# aggregate document updates
|
||||
|
||||
for (
|
||||
processed_document
|
||||
) in (
|
||||
unprocessed_documents_list
|
||||
): # This will need to change if we do not materialize docs
|
||||
document_kg_updates[processed_document.id] = KGUDocumentUpdateRequest(
|
||||
document_id=processed_document.id,
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
updated_chunk_batches = get_document_chunks_for_kg_processing(
|
||||
processed_document.id,
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
|
||||
for updated_chunk_batch in updated_chunk_batches:
|
||||
for updated_chunk in updated_chunk_batch:
|
||||
chunk_entities = updated_chunk.entities
|
||||
chunk_relationships = updated_chunk.relationships
|
||||
chunk_terms = updated_chunk.terms
|
||||
document_kg_updates[processed_document.id].entities.update(
|
||||
chunk_entities
|
||||
)
|
||||
document_kg_updates[processed_document.id].relationships.update(
|
||||
chunk_relationships
|
||||
)
|
||||
document_kg_updates[processed_document.id].terms.update(chunk_terms)
|
||||
|
||||
aggregated_kg_extractions = aggregate_kg_extractions(
|
||||
connector_aggregated_kg_extractions_list
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tracked_entity_types = [
|
||||
x.id_name for x in get_entity_types(db_session, active=None)
|
||||
]
|
||||
|
||||
# Populate the KG database with the extracted entities, relationships, and terms
|
||||
for (
|
||||
entity,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.entities.items():
|
||||
if len(entity.split(":")) != 2:
|
||||
logger.error(
|
||||
f"Invalid entity {entity} in aggregated_kg_extractions.entities"
|
||||
)
|
||||
continue
|
||||
|
||||
entity_type, entity_name = entity.split(":")
|
||||
entity_type = entity_type.upper()
|
||||
entity_name = entity_name.capitalize()
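# Illustrative (hypothetical values): an extracted entity "account:acme corp"
# becomes entity_type "ACCOUNT" and entity_name "Acme corp".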
|
||||
|
||||
if entity_type not in tracked_entity_types:
|
||||
continue
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if (
|
||||
entity
|
||||
not in aggregated_kg_extractions.grounded_entities_document_ids
|
||||
):
|
||||
add_entity(
|
||||
db_session=db_session,
|
||||
entity_type=entity_type,
|
||||
name=entity_name,
|
||||
cluster_count=extraction_count,
|
||||
)
|
||||
else:
|
||||
event_time = get_document_updated_at(
|
||||
entity,
|
||||
db_session,
|
||||
)
|
||||
add_entity(
|
||||
db_session=db_session,
|
||||
entity_type=entity_type,
|
||||
name=entity_name,
|
||||
cluster_count=extraction_count,
|
||||
document_id=aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
entity
|
||||
],
|
||||
event_time=event_time,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding entity {entity} to the database: {e}")
|
||||
|
||||
relationship_type_counter: dict[str, int] = defaultdict(int)
|
||||
|
||||
for (
|
||||
relationship,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.relationships.items():
|
||||
relationship_split = relationship.split("__")
|
||||
|
||||
|
||||
source_entity = relationship_split[0]
|
||||
relationship_type = " ".join(relationship_split[1:-1]).replace("__", "_")
|
||||
target_entity = relationship_split[-1]
|
||||
|
||||
source_entity_type = source_entity.split(":")[0]
|
||||
target_entity_type = target_entity.split(":")[0]
|
||||
|
||||
if (
|
||||
source_entity_type not in tracked_entity_types
|
||||
or target_entity_type not in tracked_entity_types
|
||||
):
|
||||
continue
|
||||
|
||||
source_entity_general = f"{source_entity_type.upper()}"
|
||||
target_entity_general = f"{target_entity_type.upper()}"
|
||||
relationship_type_id_name = (
|
||||
f"{source_entity_general}__{relationship_type.lower()}__"
|
||||
f"{target_entity_general}"
|
||||
)
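# Illustrative (hypothetical values): the relationship
# "ACCOUNT:Acme__uses__PRODUCT:Widget" is counted under the relationship
# type id "ACCOUNT__uses__PRODUCT".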
|
||||
|
||||
relationship_type_counter[relationship_type_id_name] += extraction_count
|
||||
|
||||
for (
|
||||
relationship_type_id_name,
|
||||
extraction_count,
|
||||
) in relationship_type_counter.items():
|
||||
(
|
||||
source_entity_type,
|
||||
relationship_type,
|
||||
target_entity_type,
|
||||
) = relationship_type_id_name.split("__")
|
||||
|
||||
if (
|
||||
source_entity_type not in tracked_entity_types
|
||||
or target_entity_type not in tracked_entity_types
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
add_relationship_type(
|
||||
db_session=db_session,
|
||||
source_entity_type=source_entity_type.upper(),
|
||||
relationship_type=relationship_type,
|
||||
target_entity_type=target_entity_type.upper(),
|
||||
definition=False,
|
||||
extraction_count=extraction_count,
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
|
||||
)
|
||||
|
||||
for (
|
||||
relationship,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.relationships.items():
|
||||
relationship_split = relationship.split("__")
|
||||
|
||||
|
||||
source_entity = relationship_split[0]
|
||||
relationship_type = (
|
||||
" ".join(relationship_split[1:-1]).replace("__", " ").replace("_", " ")
|
||||
)
|
||||
target_entity = relationship_split[-1]
|
||||
|
||||
source_entity_type = source_entity.split(":")[0]
|
||||
target_entity_type = target_entity.split(":")[0]
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
add_relationship(db_session, relationship, extraction_count)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship {relationship} to the database: {e}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
add_or_increment_relationship(db_session, relationship)
|
||||
db_session.commit()
|
||||
|
||||
# Populate the Documents table with the kg information for the documents
|
||||
|
||||
for document_id, document_kg_update in document_kg_updates.items():
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_info(
|
||||
db_session,
|
||||
document_id,
|
||||
kg_processed=True,
|
||||
kg_data={
|
||||
"entities": list(document_kg_update.entities),
|
||||
"relationships": list(document_kg_update.relationships),
|
||||
"terms": list(document_kg_update.terms),
|
||||
},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return connector_extraction_stats
|
||||
|
||||
|
||||
def _kg_chunk_batch_extraction(
|
||||
chunks: list[KGChunkFormat],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> KGBatchExtractionStats:
|
||||
_, fast_llm = get_default_llms()
|
||||
|
||||
succeeded_chunk_id: list[KGChunkId] = []
|
||||
failed_chunk_id: list[KGChunkId] = []
|
||||
succeeded_chunk_extraction: list[KGChunkExtraction] = []
|
||||
|
||||
preformatted_prompt = MASTER_EXTRACTION_PROMPT.format(
|
||||
entity_types=get_entity_types_str(active=True)
|
||||
)
|
||||
|
||||
def process_single_chunk(
|
||||
chunk: KGChunkFormat, preformatted_prompt: str
|
||||
) -> tuple[bool, KGUChunkUpdateRequest]:
|
||||
"""Process a single chunk and return success status and chunk ID."""
|
||||
|
||||
# For now, we're just processing the content
|
||||
# TODO: Implement actual prompt application logic
|
||||
|
||||
llm_preprocessing = prepare_llm_content(chunk)
|
||||
|
||||
formatted_prompt = preformatted_prompt.replace(
|
||||
"---content---", llm_preprocessing.llm_context
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=formatted_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"LLM Extraction from chunk {chunk.chunk_id} from doc {chunk.document_id}"
|
||||
)
|
||||
raw_extraction_result = fast_llm.invoke(msg)
|
||||
extraction_result = message_to_string(raw_extraction_result)
|
||||
|
||||
try:
|
||||
cleaned_result = (
|
||||
extraction_result.replace("```json", "").replace("```", "").strip()
|
||||
)
|
||||
parsed_result = json.loads(cleaned_result)
|
||||
extracted_entities = parsed_result.get("entities", [])
|
||||
extracted_relationships = parsed_result.get("relationships", [])
|
||||
extracted_terms = parsed_result.get("terms", [])
|
||||
|
||||
implied_extracted_relationships = [
|
||||
llm_preprocessing.core_entity
|
||||
+ "__"
|
||||
+ "relates_neutrally_to"
|
||||
+ "__"
|
||||
+ entity
|
||||
for entity in extracted_entities
|
||||
]
|
||||
|
||||
all_entities = set(
|
||||
list(extracted_entities)
|
||||
+ list(llm_preprocessing.implied_entities)
|
||||
+ list(
|
||||
generalize_entities(
|
||||
extracted_entities + llm_preprocessing.implied_entities
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"All entities: {all_entities}")
|
||||
|
||||
all_relationships = (
|
||||
extracted_relationships
|
||||
+ llm_preprocessing.implied_relationships
|
||||
+ implied_extracted_relationships
|
||||
)
|
||||
all_relationships = set(
|
||||
list(all_relationships)
|
||||
+ list(generalize_relationships(all_relationships))
|
||||
)
|
||||
|
||||
kg_updates = [
|
||||
KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity=llm_preprocessing.core_entity,
|
||||
entities=all_entities,
|
||||
relationships=all_relationships,
|
||||
terms=set(extracted_terms),
|
||||
),
|
||||
]
|
||||
|
||||
update_kg_chunks_vespa_info(
|
||||
kg_update_requests=kg_updates,
|
||||
index_name=index_name,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KG updated: {chunk.chunk_id} from doc {chunk.document_id}"
|
||||
)
|
||||
|
||||
return True, kg_updates[0] # only single chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Invalid JSON format for extraction of chunk {chunk.chunk_id} \
|
||||
from doc {chunk.document_id}: {str(e)}"
|
||||
)
|
||||
logger.error(f"Raw output: {extraction_result}")
|
||||
return False, KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity=llm_preprocessing.core_entity,
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process chunk {chunk.chunk_id} from doc {chunk.document_id}: {str(e)}"
|
||||
)
|
||||
return False, KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity="",
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
# Assume for prototype: use_threads = True. TODO: Make thread safe!
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(process_single_chunk, (chunk, preformatted_prompt)) for chunk in chunks
|
||||
]
|
||||
|
||||
logger.debug("Running KG extraction on chunks in parallel")
|
||||
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
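# Note: with allow_failures=True the parallel runner may return None for a call
# that raised (the commented-out llm_batch_eval_sections below filters such
# None entries). The unpacking below assumes every entry is a
# (success, update_request) tuple.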
|
||||
|
||||
# Sort results into succeeded and failed
|
||||
for success, chunk_results in results:
|
||||
if success:
|
||||
succeeded_chunk_id.append(
|
||||
KGChunkId(
|
||||
document_id=chunk_results.document_id,
|
||||
chunk_id=chunk_results.chunk_id,
|
||||
)
|
||||
)
|
||||
succeeded_chunk_extraction.append(chunk_results)
|
||||
else:
|
||||
failed_chunk_id.append(
|
||||
KGChunkId(
|
||||
document_id=chunk_results.document_id,
|
||||
chunk_id=chunk_results.chunk_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Collect data for postgres later on
|
||||
|
||||
aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
|
||||
for chunk_result in succeeded_chunk_extraction:
|
||||
aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
chunk_result.core_entity
|
||||
] = chunk_result.document_id
|
||||
|
||||
mentioned_chunk_entities: set[str] = set()
|
||||
for relationship in chunk_result.relationships:
|
||||
relationship_split = relationship.split("__")
|
||||
if len(relationship_split) == 3:
|
||||
if relationship_split[0] not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[relationship_split[0]] += 1
|
||||
mentioned_chunk_entities.add(relationship_split[0])
|
||||
if relationship_split[2] not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[relationship_split[2]] += 1
|
||||
mentioned_chunk_entities.add(relationship_split[2])
|
||||
aggregated_kg_extractions.relationships[relationship] += 1
|
||||
|
||||
for kg_entity in chunk_result.entities:
|
||||
if kg_entity not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[kg_entity] += 1
|
||||
mentioned_chunk_entities.add(kg_entity)
|
||||
|
||||
for kg_term in chunk_result.terms:
|
||||
aggregated_kg_extractions.terms[kg_term] += 1
|
||||
|
||||
return KGBatchExtractionStats(
|
||||
connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
|
||||
succeeded=succeeded_chunk_id,
|
||||
failed=failed_chunk_id,
|
||||
aggregated_kg_extractions=aggregated_kg_extractions,
|
||||
)
|
||||
|
||||
|
||||
def _kg_document_classification(
|
||||
document_classification_content_list: list[KGClassificationContent],
|
||||
classification_instructions: dict[str, KGClassificationInstructionStrings],
|
||||
) -> list[tuple[bool, KGClassificationDecisions]]:
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
def classify_single_document(
|
||||
document_classification_content: KGClassificationContent,
|
||||
classification_instructions: dict[str, KGClassificationInstructionStrings],
|
||||
) -> tuple[bool, KGClassificationDecisions]:
|
||||
"""Classify a single document whether it should be kg-processed or not"""
|
||||
|
||||
source = document_classification_content.source_type
|
||||
document_id = document_classification_content.document_id
|
||||
|
||||
if source not in classification_instructions:
|
||||
logger.info(
|
||||
f"Source {source} did not have kg classification instructions. No content analysis."
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
classification_prompt = prepare_llm_document_content(
|
||||
document_classification_content,
|
||||
category_list=classification_instructions[source].classification_options,
|
||||
category_definitions=classification_instructions[
|
||||
source
|
||||
].classification_class_definitions,
|
||||
)
|
||||
|
||||
if classification_prompt.llm_prompt is None:
|
||||
logger.info(
|
||||
f"Source {source} did not have kg document classification instructions. No content analysis."
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=classification_prompt.llm_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"LLM Classification from document {document_classification_content.document_id}"
|
||||
)
|
||||
raw_classification_result = primary_llm.invoke(msg)
|
||||
classification_result = (
|
||||
message_to_string(raw_classification_result)
|
||||
.replace("```json", "")
|
||||
.replace("```", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
classification_class = classification_result.split("CATEGORY:")[1].strip()
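# Assumes the classification prompt asks the LLM to answer with a line such as
# "CATEGORY: <label>"; if the marker is missing, the resulting IndexError is
# caught by the except block below and the document is skipped.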
|
||||
|
||||
if (
|
||||
classification_class
|
||||
in classification_instructions[source].classification_class_definitions
|
||||
):
|
||||
extraction_decision = cast(
|
||||
bool,
|
||||
classification_instructions[
|
||||
source
|
||||
].classification_class_definitions[classification_class][
|
||||
"extraction"
|
||||
],
|
||||
)
|
||||
else:
|
||||
extraction_decision = False
|
||||
|
||||
return True, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=extraction_decision,
|
||||
classification_class=classification_class,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to classify document {document_classification_content.document_id}: {str(e)}"
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
# Assume for prototype: use_threads = True. TODO: Make thread safe!
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(
|
||||
classify_single_document,
|
||||
(document_classification_content, classification_instructions),
|
||||
)
|
||||
for document_classification_content in document_classification_content_list
|
||||
]
|
||||
|
||||
logger.debug("Running KG classification on documents in parallel")
|
||||
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# logger.debug("Running KG extraction on chunks in parallel")
|
||||
# results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
|
||||
|
||||
# # Sort results into succeeded and failed
|
||||
# for success, chunk_results in results:
|
||||
# if success:
|
||||
# succeeded_chunk_id.append(
|
||||
# KGChunkId(
|
||||
# document_id=chunk_results.document_id,
|
||||
# chunk_id=chunk_results.chunk_id,
|
||||
# )
|
||||
# )
|
||||
# succeeded_chunk_extraction.append(chunk_results)
|
||||
# else:
|
||||
# failed_chunk_id.append(
|
||||
# KGChunkId(
|
||||
# document_id=chunk_results.document_id,
|
||||
# chunk_id=chunk_results.chunk_id,
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Collect data for postgres later on
|
||||
|
||||
# aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
# grounded_entities_document_ids=defaultdict(str),
|
||||
# entities=defaultdict(int),
|
||||
# relationships=defaultdict(int),
|
||||
# terms=defaultdict(int),
|
||||
# )
|
||||
|
||||
# for chunk_result in succeeded_chunk_extraction:
|
||||
# aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
# chunk_result.core_entity
|
||||
# ] = chunk_result.document_id
|
||||
|
||||
# mentioned_chunk_entities: set[str] = set()
|
||||
# for relationship in chunk_result.relationships:
|
||||
# relationship_split = relationship.split("__")
|
||||
# if len(relationship_split) == 3:
|
||||
# if relationship_split[0] not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[relationship_split[0]] += 1
|
||||
# mentioned_chunk_entities.add(relationship_split[0])
|
||||
# if relationship_split[2] not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[relationship_split[2]] += 1
|
||||
# mentioned_chunk_entities.add(relationship_split[2])
|
||||
# aggregated_kg_extractions.relationships[relationship] += 1
|
||||
|
||||
# for kg_entity in chunk_result.entities:
|
||||
# if kg_entity not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[kg_entity] += 1
|
||||
# mentioned_chunk_entities.add(kg_entity)
|
||||
|
||||
# for kg_term in chunk_result.terms:
|
||||
# aggregated_kg_extractions.terms[kg_term] += 1
|
||||
|
||||
# return KGBatchExtractionStats(
|
||||
# connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
|
||||
# succeeded=succeeded_chunk_id,
|
||||
# failed=failed_chunk_id,
|
||||
# aggregated_kg_extractions=aggregated_kg_extractions,
|
||||
# )
|
||||
|
||||
|
||||
def _kg_connector_extraction(
|
||||
connector_id: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"Starting kg extraction for connector {connector_id} for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# - grab kg type data from postgres
|
||||
|
||||
# - construct prompt
|
||||
|
||||
# find all documents for the connector that have not been kg-processed
|
||||
|
||||
# - loop for :
|
||||
|
||||
# - grab a number of chunks from vespa
|
||||
|
||||
# - convert them into the KGChunk format
|
||||
|
||||
# - run the extractions in parallel
|
||||
|
||||
# - save the results
|
||||
|
||||
# - mark chunks as processed
|
||||
|
||||
# - update the connector status
|
||||
|
||||
|
||||
#
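# --- Illustrative sketch (not part of the original commit) ---
# A minimal version of the loop outlined in the comments above, assuming the
# helpers shown elsewhere in this branch keep their signatures
# (get_unprocessed_kg_documents_for_connector, get_document_chunks_for_kg_processing,
# _kg_chunk_batch_extraction). The index_name and batch_size parameters are
# assumed to be passed in here for simplicity; carryover batching, document
# classification, document KG updates, and connector-status bookkeeping are
# intentionally omitted.
def _kg_connector_extraction_sketch(
    connector_id: str,
    tenant_id: str,
    index_name: str,
    batch_size: int = 8,
) -> list[KGBatchExtractionStats]:
    batch_stats: list[KGBatchExtractionStats] = []

    # find all documents for the connector that have not been kg-processed
    with get_session_with_current_tenant() as db_session:
        unprocessed_documents = list(
            get_unprocessed_kg_documents_for_connector(db_session, connector_id)
        )

    for document in unprocessed_documents:
        # grab chunks from Vespa in KGChunkFormat batches
        for chunk_batch in get_document_chunks_for_kg_processing(
            document.id, index_name, batch_size=batch_size
        ):
            # run the extractions (parallelized inside the helper) and keep the stats
            batch_stats.append(
                _kg_chunk_batch_extraction(chunk_batch, index_name, tenant_id)
            )

    return batch_stats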
|
||||
backend/onyx/kg/extractions/llm_extraction.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
"""
|
||||
def kg_extraction_section(
|
||||
kg_chunks: list[KGChunk],
|
||||
llm: LLM,
|
||||
max_workers: int = 10,
|
||||
) -> list[KGChunkExtraction]:
|
||||
def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
|
||||
metadata_str = "\nMetadata:\n"
|
||||
for key, value in metadata.items():
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
metadata_str += f"{key} - {value_str}\n"
|
||||
return metadata_str
|
||||
|
||||
def _get_usefulness_messages() -> list[dict[str, str]]:
|
||||
metadata_str = _get_metadata_str(metadata) if metadata else ""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SECTION_FILTER_PROMPT.format(
|
||||
title=title.replace("\n", " "),
|
||||
chunk_text=section_content,
|
||||
user_query=query,
|
||||
optional_metadata=metadata_str,
|
||||
),
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
def _extract_usefulness(model_output: str) -> bool:
|
||||
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
messages = _get_usefulness_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(llm.invoke(filled_llm_prompt))
|
||||
logger.debug(model_output)
|
||||
|
||||
return _extract_usefulness(model_output)
|
||||
|
||||
"""
|
||||
"""
|
||||
def llm_batch_eval_sections(
|
||||
query: str,
|
||||
section_contents: list[str],
|
||||
llm: LLM,
|
||||
titles: list[str],
|
||||
metadata_list: list[dict[str, str | list[str]]],
|
||||
use_threads: bool = True,
|
||||
) -> list[bool]:
|
||||
if DISABLE_LLM_DOC_RELEVANCE:
|
||||
raise RuntimeError(
|
||||
"LLM Doc Relevance is globally disabled, "
|
||||
"this should have been caught upstream."
|
||||
)
|
||||
|
||||
if use_threads:
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(llm_eval_section, (query, section_content, llm, title, metadata))
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Running LLM usefulness eval in parallel (following logging may be out of order)"
|
||||
)
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
|
||||
# In case of failure/timeout, don't throw out the section
|
||||
return [True if item is None else item for item in parallel_results]
|
||||
|
||||
else:
|
||||
return [
|
||||
llm_eval_section(query, section_content, llm, title, metadata)
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
"""
|
||||
backend/onyx/kg/models.py (new file, 99 lines)
@@ -0,0 +1,99 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KGChunkFormat(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
title: str
|
||||
content: str
|
||||
primary_owners: list[str]
|
||||
secondary_owners: list[str]
|
||||
source_type: str
|
||||
metadata: dict[str, str | list[str]] | None = None
|
||||
entities: Dict[str, int] = {}
|
||||
relationships: Dict[str, int] = {}
|
||||
terms: Dict[str, int] = {}
|
||||
|
||||
|
||||
class KGChunkExtraction(BaseModel):
|
||||
connector_id: int
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
|
||||
|
||||
class KGChunkId(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
|
||||
|
||||
class KGAggregatedExtractions(BaseModel):
|
||||
grounded_entities_document_ids: defaultdict[str, str]
|
||||
entities: defaultdict[str, int]
|
||||
relationships: defaultdict[str, int]
|
||||
terms: defaultdict[str, int]
|
||||
|
||||
|
||||
class KGBatchExtractionStats(BaseModel):
|
||||
connector_id: int | None = None
|
||||
succeeded: list[KGChunkId]
|
||||
failed: list[KGChunkId]
|
||||
aggregated_kg_extractions: KGAggregatedExtractions
|
||||
|
||||
|
||||
class ConnectorExtractionStats(BaseModel):
|
||||
connector_id: int
|
||||
num_succeeded: int
|
||||
num_failed: int
|
||||
num_processed: int
|
||||
|
||||
|
||||
class KGPerson(BaseModel):
|
||||
name: str
|
||||
company: str
|
||||
employee: bool
|
||||
|
||||
|
||||
class NormalizedEntities(BaseModel):
|
||||
entities: list[str]
|
||||
entity_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedRelationships(BaseModel):
|
||||
relationships: list[str]
|
||||
relationship_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedTerms(BaseModel):
|
||||
terms: list[str]
|
||||
term_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class KGClassificationContent(BaseModel):
|
||||
document_id: str
|
||||
classification_content: str
|
||||
source_type: str
|
||||
|
||||
|
||||
class KGClassificationDecisions(BaseModel):
|
||||
document_id: str
|
||||
classification_decision: bool
|
||||
classification_class: str | None
|
||||
|
||||
|
||||
class KGClassificationRule(BaseModel):
|
||||
description: str
|
||||
extraction: bool
|
||||
|
||||
|
||||
class KGClassificationInstructionStrings(BaseModel):
|
||||
classification_options: str
|
||||
classification_class_definitions: dict[str, Dict[str, str | bool]]
|
||||
backend/onyx/kg/utils/chunk_preprocessing.py (new file, 55 lines)
@@ -0,0 +1,55 @@
|
||||
from typing import Dict
|
||||
|
||||
from onyx.kg.context_preparations_extraction.fireflies import (
|
||||
prepare_llm_content_fireflies,
|
||||
)
|
||||
from onyx.kg.context_preparations_extraction.fireflies import (
|
||||
prepare_llm_document_content_fireflies,
|
||||
)
|
||||
from onyx.kg.context_preparations_extraction.models import ContextPreparation
|
||||
from onyx.kg.context_preparations_extraction.models import (
|
||||
KGDocumentClassificationPrompt,
|
||||
)
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
|
||||
|
||||
def prepare_llm_content(chunk: KGChunkFormat) -> ContextPreparation:
|
||||
"""
|
||||
Prepare the content for the LLM.
|
||||
"""
|
||||
if chunk.source_type == "fireflies":
|
||||
return prepare_llm_content_fireflies(chunk)
|
||||
|
||||
else:
|
||||
return ContextPreparation(
|
||||
llm_context=chunk.content,
|
||||
core_entity="",
|
||||
implied_entities=[],
|
||||
implied_relationships=[],
|
||||
implied_terms=[],
|
||||
)
|
||||
|
||||
|
||||
def prepare_llm_document_content(
|
||||
document_classification_content: KGClassificationContent,
|
||||
category_list: str,
|
||||
category_definitions: dict[str, Dict[str, str | bool]],
|
||||
) -> KGDocumentClassificationPrompt:
|
||||
"""
|
||||
Prepare the content for the extraction classification.
|
||||
"""
|
||||
|
||||
category_definition_string = ""
|
||||
for category, category_data in category_definitions.items():
|
||||
category_definition_string += f"{category}: {category_data['description']}\n"
|
||||
|
||||
if document_classification_content.source_type == "fireflies":
|
||||
return prepare_llm_document_content_fireflies(
|
||||
document_classification_content, category_list, category_definition_string
|
||||
)
|
||||
|
||||
else:
|
||||
return KGDocumentClassificationPrompt(
|
||||
llm_prompt=None,
|
||||
)
|
||||
backend/onyx/kg/utils/connector_utils.py (new file, 41 lines)
@@ -0,0 +1,41 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_unprocessed_connector_ids(tenant_id: str) -> list[int]:
|
||||
"""
|
||||
Retrieves a list of connector IDs that have not been KG processed for a given tenant.
|
||||
|
||||
Args:
|
||||
tenant_id (str): The ID of the tenant to check for unprocessed connectors
|
||||
|
||||
Returns:
|
||||
list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
|
||||
"""
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Find connectors that:
|
||||
# 1. Have KG extraction enabled
|
||||
# 2. Have documents that haven't been KG processed
|
||||
stmt = (
|
||||
select(Connector.id)
|
||||
.distinct()
|
||||
.join(DocumentByConnectorCredentialPair)
|
||||
.where(
|
||||
Connector.kg_extraction_enabled,
|
||||
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(False),
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
|
||||
return []
|
||||
backend/onyx/kg/utils/embeddings.py (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
def encode_string_batch(strings: List[str]) -> np.ndarray:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
# Get embeddings while session is still open
|
||||
embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
|
||||
return np.array(embedding)
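# Usage sketch (assumes a reachable model server at MODEL_SERVER_HOST/PORT):
# encode_string_batch(["Acme", "Acme Corp"]) returns a numpy array of shape
# (2, embedding_dim).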
|
||||
backend/onyx/kg/utils/formatting_utils.py (new file, 123 lines)
@@ -0,0 +1,123 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.configs.kg_configs import KG_OWN_COMPANY
|
||||
from onyx.configs.kg_configs import KG_OWN_EMAIL_DOMAINS
|
||||
from onyx.kg.models import KGAggregatedExtractions
|
||||
from onyx.kg.models import KGPerson
|
||||
|
||||
|
||||
def format_entity(entity: str) -> str:
|
||||
if len(entity.split(":")) == 2:
|
||||
entity_type, entity_name = entity.split(":")
|
||||
return f"{entity_type.upper()}:{entity_name.title()}"
|
||||
else:
|
||||
return entity
|
||||
|
||||
|
||||
def format_relationship(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{format_entity(source_node)}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{format_entity(target_node)}"
|
||||
)
|
||||
|
||||
|
||||
def format_relationship_type(relationship_type: str) -> str:
|
||||
source_node_type, relationship_type, target_node_type = relationship_type.split(
|
||||
"__"
|
||||
)
|
||||
return (
|
||||
f"{source_node_type.upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node_type.upper()}"
|
||||
)
|
||||
|
||||
|
||||
def generate_relationship_type(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{source_node.split(':')[0].upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node.split(':')[0].upper()}"
|
||||
)
|
||||
|
||||
|
||||
def aggregate_kg_extractions(
|
||||
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
|
||||
) -> KGAggregatedExtractions:
|
||||
aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
|
||||
for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
|
||||
for (
|
||||
grounded_entity,
|
||||
document_id,
|
||||
) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
|
||||
aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
grounded_entity
|
||||
] = document_id
|
||||
|
||||
for entity, count in connector_aggregated_kg_extractions.entities.items():
|
||||
aggregated_kg_extractions.entities[entity] += count
|
||||
for (
|
||||
relationship,
|
||||
count,
|
||||
) in connector_aggregated_kg_extractions.relationships.items():
|
||||
aggregated_kg_extractions.relationships[relationship] += count
|
||||
for term, count in connector_aggregated_kg_extractions.terms.items():
|
||||
aggregated_kg_extractions.terms[term] += count
|
||||
return aggregated_kg_extractions
|
||||
|
||||
|
||||
def kg_email_processing(email: str) -> KGPerson:
|
||||
"""
|
||||
Parse an email address into a KGPerson (name, company, employee flag).
|
||||
"""
|
||||
name, company_domain = email.split("@")
|
||||
|
||||
assert isinstance(KG_OWN_EMAIL_DOMAINS, list)
|
||||
|
||||
employee = any(domain in company_domain for domain in KG_OWN_EMAIL_DOMAINS)
|
||||
if employee:
|
||||
company = KG_OWN_COMPANY
|
||||
else:
|
||||
company = company_domain.capitalize()
|
||||
|
||||
return KGPerson(name=name, company=company, employee=employee)
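# Illustrative (hypothetical values): with KG_OWN_EMAIL_DOMAINS = ["onyx.app"],
# kg_email_processing("jane.doe@acme.com") returns
# KGPerson(name="jane.doe", company="Acme.com", employee=False).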
|
||||
|
||||
|
||||
def generalize_entities(entities: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize entities to their superclass.
|
||||
"""
|
||||
return set([f"{entity.split(':')[0]}:*" for entity in entities])
|
||||
|
||||
|
||||
def generalize_relationships(relationships: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize relationships to their superclass.
|
||||
"""
|
||||
generalized_relationships: set[str] = set()
|
||||
for relationship in relationships:
|
||||
assert (
|
||||
len(relationship.split("__")) == 3
|
||||
), "Relationship is not in the correct format"
|
||||
source_entity, relationship_type, target_entity = relationship.split("__")
|
||||
generalized_source_entity = list(generalize_entities([source_entity]))[0]
|
||||
generalized_target_entity = list(generalize_entities([target_entity]))[0]
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
return generalized_relationships
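# Illustrative (hypothetical values):
# generalize_relationships(["ACCOUNT:Acme__uses__PRODUCT:Widget"]) yields
#   "ACCOUNT:*__uses__PRODUCT:Widget",
#   "ACCOUNT:Acme__uses__PRODUCT:*",
#   "ACCOUNT:*__uses__PRODUCT:*".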
|
||||
backend/onyx/kg/vespa/vespa_interactions.py (new file, 179 lines)
@@ -0,0 +1,179 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.kg.context_preparations_extraction.fireflies import (
|
||||
get_classification_content_from_fireflies_chunks,
|
||||
)
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
|
||||
|
||||
def get_document_classification_content_for_kg_processing(
|
||||
document_ids: list[str],
|
||||
index_name: str,
|
||||
batch_size: int = 8,
|
||||
num_classification_chunks: int = 3,
|
||||
) -> Generator[list[KGClassificationContent], None, None]:
|
||||
"""
|
||||
Generates the content used for initial classification of a document from
|
||||
the first num_classification_chunks chunks.
|
||||
"""
|
||||
|
||||
classification_content_list: list[KGClassificationContent] = []
|
||||
|
||||
for i in range(0, len(document_ids), batch_size):
|
||||
batch_document_ids = document_ids[i : i + batch_size]
|
||||
for document_id in batch_document_ids:
|
||||
# fetch the first chunks of the document from Vespa via the visit API
|
||||
first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(
|
||||
document_id=document_id,
|
||||
max_chunk_ind=num_classification_chunks - 1,
|
||||
min_chunk_ind=0,
|
||||
),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"source_type",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
first_num_classification_chunks = sorted(
|
||||
first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
|
||||
)[:num_classification_chunks]
|
||||
|
||||
classification_content = _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks
|
||||
)
|
||||
|
||||
classification_content_list.append(
|
||||
KGClassificationContent(
|
||||
document_id=document_id,
|
||||
classification_content=classification_content,
|
||||
source_type=first_num_classification_chunks[0]["fields"][
|
||||
"source_type"
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Yield the batch of classification content
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
classification_content_list = []
|
||||
|
||||
# Yield any remaining items
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
|
||||
|
||||
def get_document_chunks_for_kg_processing(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
batch_size: int = 8,
|
||||
) -> Generator[list[KGChunkFormat], None, None]:
|
||||
"""
|
||||
Retrieves chunks from Vespa for the given document and converts them to KGChunkFormat batches.
|
||||
|
||||
Args:
|
||||
document_id (str): ID of the document to fetch chunks for
index_name (str): Name of the Vespa index
batch_size (int): Maximum number of chunks per yielded batch
|
||||
|
||||
Yields:
|
||||
list[KGChunkFormat]: Batches of chunks ready for KG processing
|
||||
"""
|
||||
|
||||
current_batch: list[KGChunkFormat] = []
|
||||
|
||||
# get all chunks for the document
|
||||
chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
"source_type",
|
||||
"kg_entities",
|
||||
"kg_relationships",
|
||||
"kg_terms",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
# Convert Vespa chunks to KGChunks
|
||||
# kg_chunks: list[KGChunkFormat] = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
fields = chunk["fields"]
|
||||
if isinstance(fields.get("metadata", {}), str):
|
||||
fields["metadata"] = json.loads(fields["metadata"])
|
||||
current_batch.append(
|
||||
KGChunkFormat(
|
||||
connector_id=None, # We may need to adjust this
|
||||
document_id=fields.get("document_id"),
|
||||
chunk_id=fields.get("chunk_id"),
|
||||
primary_owners=fields.get("primary_owners", []),
|
||||
secondary_owners=fields.get("secondary_owners", []),
|
||||
source_type=fields.get("source_type", ""),
|
||||
title=fields.get("title", ""),
|
||||
content=fields.get("content", ""),
|
||||
metadata=fields.get("metadata", {}),
|
||||
entities=fields.get("kg_entities", {}),
|
||||
relationships=fields.get("kg_relationships", {}),
|
||||
terms=fields.get("kg_terms", {}),
|
||||
)
|
||||
)
|
||||
|
||||
if len(current_batch) >= batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
# Yield any remaining chunks
|
||||
if current_batch:
|
||||
yield current_batch
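# Usage sketch (hypothetical document id and index name):
# for chunk_batch in get_document_chunks_for_kg_processing(
#     "doc-123", "danswer_chunk", batch_size=8
# ):
#     ...  # each chunk_batch is a list[KGChunkFormat] with at most batch_size items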
|
||||
|
||||
|
||||
def _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Creates a KGClassificationContent object from a list of chunks.
|
||||
"""
|
||||
|
||||
source_type = first_num_classification_chunks[0]["fields"]["source_type"]
|
||||
|
||||
if source_type == "fireflies":
|
||||
classification_content = get_classification_content_from_fireflies_chunks(
|
||||
first_num_classification_chunks
|
||||
)
|
||||
|
||||
else:
|
||||
classification_content = (
|
||||
first_num_classification_chunks[0]["fields"]["title"]
|
||||
+ "\n"
|
||||
+ "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return classification_content
|
||||
backend/onyx/prompts/kg_prompts.py (new file, 1035 lines)
(File diff suppressed because it is too large)
@@ -75,12 +75,17 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
precomputed_keywords: list[str] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
document_sources: list[DocumentSource] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
ordering_only: bool | None = (
|
||||
None # Flag for fast path when search is only needed for ordering
|
||||
)
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -309,6 +309,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_sources = override_kwargs.document_sources
|
||||
time_cutoff = override_kwargs.time_cutoff
|
||||
expanded_queries = override_kwargs.expanded_queries
|
||||
kg_entities = override_kwargs.kg_entities
|
||||
kg_relationships = override_kwargs.kg_relationships
|
||||
kg_terms = override_kwargs.kg_terms
|
||||
|
||||
# Fast path for ordering-only search
|
||||
if ordering_only:
|
||||
@@ -358,6 +361,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Overwrite time-cutoff should supercede existing time-cutoff, even if defined
|
||||
retrieval_options.filters.time_cutoff = time_cutoff
|
||||
|
||||
# Initialize kg filters in retrieval options and filters with all provided values
|
||||
retrieval_options = retrieval_options or RetrievalDetails()
|
||||
retrieval_options.filters = retrieval_options.filters or BaseFilters()
|
||||
if kg_entities:
|
||||
retrieval_options.filters.kg_entities = kg_entities
|
||||
if kg_relationships:
|
||||
retrieval_options.filters.kg_relationships = kg_relationships
|
||||
if kg_terms:
|
||||
retrieval_options.filters.kg_terms = kg_terms
|
||||
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
|
||||
@@ -80,6 +80,7 @@ slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
starlette==0.46.1
|
||||
supervisor==4.2.5
|
||||
thefuzz==0.22.1
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.49.0
|
||||
|
||||