Compare commits

...

41 Commits

Author SHA1 Message Date
joachim-danswer 050c0133c7 bug fix 2025-04-14 13:10:04 -07:00
joachim-danswer a2bdbd23d8 cluster adjustments for defined relationships 2025-04-14 13:10:04 -07:00
joachim-danswer f87d44d24e progress 2025-04-14 13:10:04 -07:00
joachim-danswer 53ef95ec69 prog 2025-04-14 13:10:04 -07:00
joachim-danswer eb48354e8f progress 2025-04-14 13:08:12 -07:00
joachim-danswer 7e12f02b62 dev 2025-04-14 13:08:12 -07:00
joachim-danswer 221a4c19f0 citation work for KG queries 2025-04-14 13:08:12 -07:00
joachim-danswer 2971eb7d59 simple changes 2025-04-14 13:07:46 -07:00
joachim-danswer e7e786fd65 further graph dev 2025-04-14 13:07:46 -07:00
joachim-danswer baf4dd64b0 more inference optimizations 2025-04-14 13:07:46 -07:00
joachim-danswer 68853393ee improved SQL generation + more 2025-04-14 13:07:46 -07:00
joachim-danswer defcf8291a improved/fixed extraction 2025-04-14 13:07:46 -07:00
joachim-danswer 2ef9d19160 with base data 2025-04-14 13:07:20 -07:00
joachim-danswer 3acc069511 divcon prototype + more 2025-04-14 13:04:37 -07:00
joachim-danswer 336b31fc1e prep new agent 2025-04-14 13:00:08 -07:00
joachim-danswer e22d414d33 migration fix 2025-04-14 12:57:03 -07:00
joachim-danswer fcd749ab29 conf's from env 2025-04-14 12:57:03 -07:00
joachim-danswer 365b9b09e3 db-retrieval of determined/ungrounded ge matching 2025-04-14 12:57:03 -07:00
joachim-danswer 862807c13c clustering for determined entitied 2025-04-14 12:57:03 -07:00
joachim-danswer 15a095a068 ungrounded grounded entities 2025-04-14 12:57:03 -07:00
joachim-danswer 4600788476 document classification 2025-04-14 12:57:03 -07:00
joachim-danswer 99b6b7dd11 vendor vs account 2025-04-14 12:57:03 -07:00
joachim-danswer 49847b05f8 prompt updates 2025-04-14 12:57:03 -07:00
joachim-danswer be601d204a first simple SQL queries & clustering adjustment 2025-04-14 12:57:03 -07:00
joachim-danswer ea12c25282 add cross-relationships 2025-04-14 12:57:03 -07:00
joachim-danswer 458d7fb124 extraction improvement and more querying 2025-04-14 12:57:03 -07:00
joachim-danswer 4391d05ce3 start kg agent 2025-04-14 12:57:03 -07:00
joachim-danswer 5ec5e616f1 mypy fix 2025-04-14 12:40:43 -07:00
joachim-danswer 2cc87c7d53 e2e extract + cluster 2025-04-14 12:40:43 -07:00
joachim-danswer c017724e91 cc - pg & vespa 2025-04-14 12:40:43 -07:00
joachim-danswer e99eac4a1d cc updates 2025-04-14 12:40:43 -07:00
joachim-danswer da4f348039 fixes 2025-04-14 12:40:43 -07:00
joachim-danswer 740d4a5a9d base extraction to postgres and vespa 2025-04-14 12:40:43 -07:00
joachim-danswer c7c8330b90 more postgres changes 2025-04-14 12:40:43 -07:00
joachim-danswer 6869f0403d llm extraction -> vespa 2025-04-14 12:39:52 -07:00
joachim-danswer bacb1092ff more pg setup, & start of prompt/processing 2025-04-14 12:39:52 -07:00
joachim-danswer 1a119601e6 pg updates 2025-04-14 12:39:52 -07:00
joachim-danswer 1980fc62c0 initial KGH PG tables 2025-04-14 12:39:52 -07:00
joachim-danswer 02a4232189 small vespa nits 2025-04-14 12:39:52 -07:00
joachim-danswer f1cc6841f9 initial vespa interactions 2025-04-14 12:39:52 -07:00
pablonyx e7d2f9a43a add user files (#4152) 2025-04-14 12:38:57 -07:00
57 changed files with 7858 additions and 14 deletions

View File

@@ -0,0 +1,219 @@
"""create knowlege graph tables
Revision ID: 495cb26ce93e
Revises: 6a804aeb4830
Create Date: 2025-03-19 08:51:14.341989
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"kg_entity_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column("grounding", sa.String(), nullable=False),
sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column(
"classification_requirements",
postgresql.JSONB,
nullable=False,
server_default="{}",
),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"extraction_sources", postgresql.JSONB, nullable=False, server_default="{}"
),
sa.Column("active", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.Column("grounded_source_name", sa.String(), nullable=True, unique=True),
sa.Column(
"ge_determine_instructions", postgresql.ARRAY(sa.String()), nullable=True
),
sa.Column("ge_grounding_signature", sa.String(), nullable=True),
)
# Create KGRelationshipType table
op.create_table(
"kg_relationship_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("name", sa.String(), nullable=False, index=True),
sa.Column(
"source_entity_type_id_name", sa.String(), nullable=False, index=True
),
sa.Column(
"target_entity_type_id_name", sa.String(), nullable=False, index=True
),
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column("type", sa.String(), nullable=False, index=True),
sa.Column("active", sa.Boolean(), nullable=False, default=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
),
sa.ForeignKeyConstraint(
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
),
)
# Create KGEntity table
op.create_table(
"kg_entity",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("name", sa.String(), nullable=False, index=True),
sa.Column("document_id", sa.String(), nullable=True, index=True),
sa.Column(
"alternative_names",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column(
"keywords",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
),
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
)
op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
op.create_index(
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
)
# Create KGRelationship table
op.create_table(
"kg_relationship",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("source_node", sa.String(), nullable=False, index=True),
sa.Column("target_node", sa.String(), nullable=False, index=True),
sa.Column("type", sa.String(), nullable=False, index=True),
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
sa.ForeignKeyConstraint(
["relationship_type_id_name"], ["kg_relationship_type.id_name"]
),
sa.UniqueConstraint(
"source_node",
"target_node",
"type",
name="uq_kg_relationship_source_target_type",
),
)
op.create_index(
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
)
# Create KGTerm table
op.create_table(
"kg_term",
sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column(
"entity_types",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
)
op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
op.create_index("ix_search_term_term", "kg_term", ["id_term"])
op.add_column(
"document",
sa.Column("kg_processed", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"document",
sa.Column("kg_data", postgresql.JSONB(), nullable=False, server_default="{}"),
)
op.add_column(
"connector",
sa.Column(
"kg_extraction_enabled",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_kg_processed", sa.Boolean(), nullable=True),
)
def downgrade() -> None:
# Drop tables in reverse order of creation to handle dependencies
op.drop_table("kg_term")
op.drop_table("kg_relationship")
op.drop_table("kg_entity")
op.drop_table("kg_relationship_type")
op.drop_table("kg_entity_type")
op.drop_column("connector", "kg_extraction_enabled")
op.drop_column("document_by_connector_credential_pair", "has_been_kg_processed")
op.drop_column("document", "kg_data")
op.drop_column("document", "kg_processed")

View File

@@ -0,0 +1,18 @@
from pydantic import BaseModel
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
terms: list[str]
time_filter: str | None
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
terms: list[str]
time_filter: str | None
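Reviewer note: a quick sketch of how these models are populated from an LLM's JSON output (the nodes later in this diff call model_validate_json the same way); the payload below is invented for illustration.

raw = '{"entities": ["ACCOUNT:*"], "terms": ["renewal"], "time_filter": null}'
result = KGQuestionEntityExtractionResult.model_validate_json(raw)
print(result.entities)     # ['ACCOUNT:*']
print(result.time_filter)  # None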

View File

@@ -37,7 +37,7 @@ def research_object_source(
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
graph_config.inputs.search_request.query
object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.search_request.persona is None:

View File

@@ -17,6 +17,9 @@ def research(
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
kg_entities: list[str] | None = None,
kg_relationships: list[str] | None = None,
kg_terms: list[str] | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
@@ -33,6 +36,9 @@ def research(
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
kg_entities=kg_entities,
kg_relationships=kg_relationships,
kg_terms=kg_terms,
),
):
# get retrieved docs to send to the rest of the graph
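Reviewer note: with the new keyword arguments, a call to research might look like the sketch below (mirroring the call in process_individual_deep_search further down in this diff). The filter values are placeholders and search_tool is assumed to be an already-instantiated SearchTool.

retrieved_docs = research(
    question="Which accounts renewed last quarter?",
    search_tool=search_tool,  # assumed to exist in the calling scope
    kg_entities=["ACCOUNT:*"],
    kg_relationships=["ACCOUNT:*__renewed_by__EMPLOYEE:*"],
    kg_terms=["renewal"],
)
for doc in retrieved_docs:
    print(doc.content[:80])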

View File

@@ -0,0 +1,54 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal
from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
def simple_vs_search(
state: MainState,
) -> Literal["process_kg_only_answers", "construct_deep_search_filters"]:
if state.strategy == KGAnswerStrategy.DEEP or len(state.relationships) > 0:
return "construct_deep_search_filters"
else:
return "process_kg_only_answers"
def research_individual_object(
state: MainState,
# ) -> list[Send | Hashable] | Literal["individual_deep_search"]:
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
# if (
# not state.div_con_entities
# or not state.broken_down_question
# or not state.vespa_filter_results
# ):
# return "individual_deep_search"
# else:
assert state.div_con_entities is not None
assert state.broken_down_question is not None
assert state.vespa_filter_results is not None
return [
Send(
"process_individual_deep_search",
ResearchObjectInput(
entity=entity,
broken_down_question=state.broken_down_question,
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for entity in state.div_con_entities
]

View File

@@ -0,0 +1,138 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b3_consoldidate_individual_deep_search import (
consoldidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def kb_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
### Add nodes ###
graph.add_node(
"extract_ert",
extract_ert,
)
graph.add_node(
"generate_simple_sql",
generate_simple_sql,
)
graph.add_node(
"analyze",
analyze,
)
graph.add_node(
"generate_answer",
generate_answer,
)
graph.add_node(
"construct_deep_search_filters",
construct_deep_search_filters,
)
graph.add_node(
"process_individual_deep_search",
process_individual_deep_search,
)
# graph.add_node(
# "individual_deep_search",
# individual_deep_search,
# )
graph.add_node(
"consoldidate_individual_deep_search",
consoldidate_individual_deep_search,
)
graph.add_node("process_kg_only_answers", process_kg_only_answers)
### Add edges ###
graph.add_edge(start_key=START, end_key="extract_ert")
graph.add_edge(
start_key="extract_ert",
end_key="analyze",
)
graph.add_edge(
start_key="analyze",
end_key="generate_simple_sql",
)
graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
# graph.add_edge(
# start_key="construct_deep_search_filters",
# end_key="process_individual_deep_search",
# )
graph.add_conditional_edges(
source="construct_deep_search_filters",
path=research_individual_object,
path_map=["process_individual_deep_search"],
)
graph.add_edge(
start_key="process_individual_deep_search",
end_key="consoldidate_individual_deep_search",
)
# graph.add_edge(
# start_key="individual_deep_search",
# end_key="consoldidate_individual_deep_search",
# )
graph.add_edge(
start_key="consoldidate_individual_deep_search", end_key="generate_answer"
)
graph.add_edge(
start_key="generate_answer",
end_key=END,
)
return graph
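Reviewer note: a minimal sketch of compiling and invoking this graph. The MainInput payload and the graph_config object are assumptions here (the nodes read their configuration from config["metadata"]["config"]); the PR itself does not show the entry point.

graph = kb_graph_builder().compile()

# graph_config is assumed to be an already-constructed GraphConfig instance.
result = graph.invoke(
    input=MainInput(log_messages=[]),  # hypothetical field values
    config={"metadata": {"config": graph_config}},
)
print(result.get("log_messages", []))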

View File

@@ -0,0 +1,217 @@
from typing import Set
from onyx.agents.agent_search.kb_search.models import KGExpendedGraphObjects
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_relationships_of_entity
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _check_entities_disconnected(
current_entities: list[str], current_relationships: list[str]
) -> bool:
"""
Check whether the entities in current_entities fail to form a single connected component under the given relationships.
Relationships are in the format: source_entity__relationship_name__target_entity
Args:
current_entities: List of entity IDs to check connectivity for
current_relationships: List of relationships in format source__relationship__target
Returns:
bool: True if at least one entity is unreachable from the others (the entities are not all connected), False if they are all connected
"""
if not current_entities:
return True
# Create a graph representation using adjacency list
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# Build the graph from relationships
for relationship in current_relationships:
try:
source, _, target = relationship.split("__")
if source in graph and target in graph:
graph[source].add(target)
graph[target].add(source) # Add reverse edge since graph is undirected
except ValueError:
raise ValueError(f"Invalid relationship format: {relationship}")
# Use BFS to check if all entities are connected
visited: set[str] = set()
start_entity = current_entities[0]
def _bfs(start: str) -> None:
queue = [start]
visited.add(start)
while queue:
current = queue.pop(0)
for neighbor in graph[current]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Start BFS from the first entity
_bfs(start_entity)
logger.debug(f"Number of visited entities: {len(visited)}")
# Check if all current_entities are in visited
return not all(entity in visited for entity in current_entities)
def create_minimal_connected_query_graph(
entities: list[str], relationships: list[str], max_depth: int = 2
) -> KGExpendedGraphObjects:
"""
Find the minimal subgraph that connects all input entities, using only general entities
(<entity_type>:*) as intermediate nodes. The subgraph will include only the relationships
necessary to connect all input entities through the shortest possible paths.
Args:
entities: Initial list of entity IDs
relationships: Initial list of relationships in format source__relationship__target
max_depth: Maximum depth to expand the graph (default: 2)
Returns:
KGExpendedGraphObjects containing expanded entities and relationships
"""
# Create copies of input lists to avoid modifying originals
expanded_entities = entities.copy()
expanded_relationships = relationships.copy()
# Keep track of original entities
original_entities = set(entities)
# Build initial graph from existing relationships
graph: dict[str, set[tuple[str, str]]] = {
entity: set() for entity in expanded_entities
}
for rel in relationships:
try:
source, rel_name, target = rel.split("__")
if source in graph and target in graph:
graph[source].add((target, rel_name))
graph[target].add((source, rel_name))
except ValueError:
continue
# For each depth level
counter = 0
while counter < max_depth:
# Find all connected components in the current graph
components = []
visited = set()
def dfs(node: str, component: set[str]) -> None:
visited.add(node)
component.add(node)
for neighbor, _ in graph.get(node, set()):
if neighbor not in visited:
dfs(neighbor, component)
# Find all components
for entity in expanded_entities:
if entity not in visited:
component: Set[str] = set()
dfs(entity, component)
components.append(component)
# If we only have one component, we're done
if len(components) == 1:
break
# Find the shortest path between any two components using general entities
shortest_path = None
shortest_path_length = float("inf")
for comp1 in components:
for comp2 in components:
if comp1 == comp2:
continue
# Try to find path between entities in different components
for entity1 in comp1:
if not any(e in original_entities for e in comp1):
continue
# entity1_type = entity1.split(":")[0]
with get_session_with_current_tenant() as db_session:
entity1_rels = get_relationships_of_entity(db_session, entity1)
for rel1 in entity1_rels:
try:
source1, rel_name1, target1 = rel1.split("__")
if source1 != entity1:
continue
target1_type = target1.split(":")[0]
general_target = f"{target1_type}:*"
# Try to find path from general_target to comp2
for entity2 in comp2:
if not any(e in original_entities for e in comp2):
continue
with get_session_with_current_tenant() as db_session:
entity2_rels = get_relationships_of_entity(
db_session, entity2
)
for rel2 in entity2_rels:
try:
source2, rel_name2, target2 = rel2.split("__")
if target2 != entity2:
continue
source2_type = source2.split(":")[0]
general_source = f"{source2_type}:*"
if general_target == general_source:
# Found a path of length 2
path = [
(entity1, rel_name1, general_target),
(general_target, rel_name2, entity2),
]
if len(path) < shortest_path_length:
shortest_path = path
shortest_path_length = len(path)
except ValueError:
continue
except ValueError:
continue
# If we found a path, add it to our graph
if shortest_path:
for source, rel_name, target in shortest_path:
# Add general entity if needed
if ":*" in source and source not in expanded_entities:
expanded_entities.append(source)
if ":*" in target and target not in expanded_entities:
expanded_entities.append(target)
# Add relationship
rel = f"{source}__{rel_name}__{target}"
if rel not in expanded_relationships:
expanded_relationships.append(rel)
# Update graph
if source not in graph:
graph[source] = set()
if target not in graph:
graph[target] = set()
graph[source].add((target, rel_name))
graph[target].add((source, rel_name))
counter += 1
logger.debug(f"Number of expanded entities: {len(expanded_entities)}")
logger.debug(f"Number of expanded relationships: {len(expanded_relationships)}")
return KGExpendedGraphObjects(
entities=expanded_entities, relationships=expanded_relationships
)
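Reviewer note: a toy illustration of the connectivity check above, calling the private helper directly. The entity and relationship strings are invented but follow the source__relationship__target convention described in the docstring.

entities = ["ACCOUNT:acme", "VENDOR:globex", "PERSON:alice"]
relationships = ["ACCOUNT:acme__works_with__VENDOR:globex"]

# PERSON:alice has no edge to the other two entities, so they do not form a
# single connected component and the check returns True.
print(_check_entities_disconnected(entities, relationships))  # True

# Adding an edge that reaches PERSON:alice connects everything.
relationships.append("PERSON:alice__manages__ACCOUNT:acme")
print(_check_entities_disconnected(entities, relationships))  # False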

View File

@@ -0,0 +1,34 @@
from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
terms: list[str]
time_filter: str | None
class KGAnswerApproach(BaseModel):
strategy: KGAnswerStrategy
format: KGAnswerFormat
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
terms: list[str]
time_filter: str | None
class KGExpendedGraphObjects(BaseModel):
entities: list[str]
relationships: list[str]

View File

@@ -0,0 +1,240 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import (
ERTExtractionUpdate,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_ert(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
"""
LangGraph node to extract the entities, relationships, and terms (ERT) referenced in the question.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
# first four lines are duplicated from generate_initial_answer
question = graph_config.inputs.search_request.query
today_date = datetime.now().strftime("%A, %Y-%m-%d")
all_entity_types = get_entity_types_str(active=None)
all_relationship_types = get_relationship_types_str(active=None)
### get the entities, terms, and filters
query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
entity_types=all_entity_types,
relationship_types=all_relationship_types,
)
query_extraction_prompt = query_extraction_pre_prompt.replace(
"---content---", question
).replace("---today_date---", today_date)
msg = [
HumanMessage(
content=query_extraction_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Call the LLM for entity extraction
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = (
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
ert_entities_string = f"Entities: {entity_extraction_result.entities}\n"
# ert_terms_string = f"Terms: {entity_extraction_result.terms}"
# ert_time_filter_string = f"Time Filter: {entity_extraction_result.time_filter}\n"
### get the relationships
# find the relationship types that match the extracted entity types
with get_session_with_current_tenant() as db_session:
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
db_session, entity_extraction_result.entities
)
query_relationship_extraction_prompt = (
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
.replace("---today_date---", today_date)
.replace(
"---relationship_type_options---",
" - " + "\n - ".join(allowed_relationship_pairs),
)
.replace("---identified_entities---", ert_entities_string)
.replace("---entity_types---", all_entity_types)
)
msg = [
HumanMessage(
content=query_relationship_extraction_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Call the LLM for relationship extraction
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
relationship_extraction_result = (
KGQuestionRelationshipExtractionResult.model_validate_json(
cleaned_response
)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
# ert_relationships_string = (
# f"Relationships: {relationship_extraction_result.relationships}\n"
# )
##
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_entities_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_relationships_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_terms_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_time_filter_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
return ERTExtractionUpdate(
entities_types_str=all_entity_types,
entities=entity_extraction_result.entities,
relationships=relationship_extraction_result.relationships,
terms=entity_extraction_result.terms,
time_filter=entity_extraction_result.time_filter,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)
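Reviewer note: the response-cleaning pattern used above (strip markdown fences, drop newlines, keep the outermost braces) is inlined again in several later nodes. A hypothetical helper capturing that pattern, shown only for illustration; no such helper exists in this PR.

def _extract_json_payload(raw: str) -> str:
    # Mirror of the inline cleaning above: remove ```json fences and newlines,
    # then keep only the outermost {...} object.
    cleaned = raw.replace("```json\n", "").replace("\n```", "").replace("\n", "")
    first = cleaned.find("{")
    last = cleaned.rfind("}")
    return cleaned[first : last + 1]

payload = _extract_json_payload(
    '```json\n{"entities": [], "terms": [], "time_filter": null}\n```'
)
entity_result = KGQuestionEntityExtractionResult.model_validate_json(payload)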

View File

@@ -0,0 +1,210 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def analyze(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
"""
LangGraph node to normalize the extracted entities, relationships, and terms, expand the query graph, and determine the answer strategy and output format.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities = state.entities
relationships = state.relationships
terms = state.terms
time_filter = state.time_filter
normalized_entities = normalize_entities(entities)
normalized_relationships = normalize_relationships(
relationships, normalized_entities.entity_normalization_map
)
normalized_terms = normalize_terms(terms)
normalized_time_filter = time_filter
# Expand the entities and relationships to make sure that entities are connected
graph_expansion = create_minimal_connected_query_graph(
normalized_entities.entities,
normalized_relationships.relationships,
max_depth=2,
)
query_graph_entities = graph_expansion.entities
query_graph_relationships = graph_expansion.relationships
# Evaluate whether a search needs to be done after identifying all entities and relationships
strategy_generation_prompt = (
STRATEGY_GENERATION_PROMPT.replace(
"---entities---", "\n".join(query_graph_entities)
)
.replace("---relationships---", "\n".join(query_graph_relationships))
.replace("---terms---", "\n".join(normalized_terms.terms))
.replace("---question---", question)
)
msg = [
HumanMessage(
content=strategy_generation_prompt,
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
# Call the LLM to determine the strategy
try:
llm_response = run_with_timeout(
20,
# fast_llm.invoke,
primary_llm.invoke,
prompt=msg,
timeout_override=5,
max_tokens=100,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
approach_extraction_result = KGAnswerApproach.model_validate_json(
cleaned_response
)
strategy = approach_extraction_result.strategy
output_format = approach_extraction_result.format
broken_down_question = approach_extraction_result.broken_down_question
divide_and_conquer = approach_extraction_result.divide_and_conquer
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
strategy = KGAnswerStrategy.DEEP
output_format = KGAnswerFormat.TEXT
broken_down_question = None
divide_and_conquer = YesNoEnum.NO
if strategy is None or output_format is None:
raise ValueError(f"Invalid strategy: {cleaned_response}")
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(normalized_entities.entities),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(normalized_relationships.relationships),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(query_graph_entities),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(query_graph_relationships),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=strategy.value,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=output_format.value,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
return AnalysisUpdate(
normalized_core_entities=normalized_entities.entities,
normalized_core_relationships=normalized_relationships.relationships,
query_graph_entities=query_graph_entities,
query_graph_relationships=query_graph_relationships,
normalized_terms=normalized_terms.terms,
normalized_time_filter=normalized_time_filter,
strategy=strategy,
broken_down_question=broken_down_question,
output_format=output_format,
divide_and_conquer=divide_and_conquer,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="analyze",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,224 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)
def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
"""
Remove aggregate functions from the SQL statement.
"""
sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
"---sql_statement---", sql_statement
)
msg = [
HumanMessage(
content=sql_aggregation_removal_prompt,
)
]
# Call the LLM to rewrite the SQL
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=800,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
sql_statement = cleaned_response.split("SQL:")[1].strip()
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
return sql_statement
def generate_simple_sql(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
"""
LangGraph node to generate a simple SQL query over the knowledge graph tables and execute it.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
state.strategy
state.output_format
simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entities_types---", entities_types_str)
.replace("---question---", question)
.replace("---query_entities---", "\n".join(state.query_graph_entities))
.replace(
"---query_relationships---", "\n".join(state.query_graph_relationships)
)
)
msg = [
HumanMessage(
content=simple_sql_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Call the LLM to generate the SQL
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=800,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
sql_statement = cleaned_response.split("SQL:")[1].strip()
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
# reasoning = cleaned_response.split("SQL:")[0].strip()
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
if _sql_is_aggregate_query(sql_statement):
individualized_sql_query = _remove_aggregation(sql_statement, llm=fast_llm)
else:
individualized_sql_query = None
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=reasoning,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=cleaned_response,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# CRITICAL: EXECUTION OF SQL NEEDS TO BE MADE SAFE FOR PRODUCTION
with get_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
query_results = (
[{"count": int(scalar_result) - 1}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
except Exception as e:
logger.error(f"Error executing SQL query: {e}")
raise e
if (
individualized_sql_query is not None
and individualized_sql_query != sql_statement
):
with get_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(individualized_sql_query))
# Handle scalar results (like COUNT)
if individualized_sql_query.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
individualized_query_results = (
[{"count": int(scalar_result) - 1}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
individualized_query_results = [dict(row._mapping) for row in rows]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
logger.error(f"Error executing Individualized SQL query: {e}")
individualized_query_results = None
else:
individualized_query_results = None
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=str(query_results),
# level=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
logger.info(f"query_results: {query_results}")
return SQLSimpleGenerationUpdate(
sql_query=sql_statement,
query_results=query_results,
individualized_sql_query=individualized_sql_query,
individualized_query_results=individualized_query_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate simple sql",
node_start_time=node_start_time,
)
],
)
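Reviewer note: given the CRITICAL comment above about making SQL execution production-safe, one possible guard is to refuse anything that is not a single read-only SELECT before executing it. This is a sketch of one option only, not the approach taken in this PR; a real guard would also rely on a restricted, read-only database role.

def _is_read_only_select(sql_statement: str) -> bool:
    # Conservative check: a single statement that starts with SELECT and
    # contains no obviously data-modifying keywords.
    statement = sql_statement.strip().rstrip(";").strip()
    upper = statement.upper()
    if ";" in upper:  # more than one statement
        return False
    forbidden = ("INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "TRUNCATE", "GRANT")
    return upper.startswith("SELECT") and not any(word in upper for word in forbidden)

print(_is_read_only_select("SELECT id_name FROM kg_entity LIMIT 5;"))  # True
print(_is_read_only_select("DROP TABLE kg_entity;"))                   # False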

View File

@@ -0,0 +1,158 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entity_types_with_grounded_source_name
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def construct_deep_search_filters(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
"""
LangGraph node to construct knowledge graph filters for the deep document search.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
entities = state.query_graph_entities
relationships = state.query_graph_relationships
simple_sql_query = state.sql_query
search_filter_construction_prompt = (
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
"---entity_type_descriptions---",
entities_types_str,
)
.replace(
"---entity_filters---",
"\n".join(entities),
)
.replace(
"---relationship_filters---",
"\n".join(relationships),
)
.replace(
"---sql_query---",
simple_sql_query or "(no SQL generated)",
)
.replace(
"---question---",
question,
)
)
msg = [
HumanMessage(
content=search_filter_construction_prompt,
)
]
llm = graph_config.tooling.primary_llm
# Call the LLM to construct the filters
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
vespa_filter_results = KGVespaFilterResults.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
if (
state.individualized_query_results
and len(state.individualized_query_results) > 0
):
div_con_entities = [
x["id_name"]
for x in state.individualized_query_results
if x["id_name"] is not None and "*" not in x["id_name"]
]
elif state.query_results and len(state.query_results) > 0:
div_con_entities = [
x["id_name"]
for x in state.query_results
if x["id_name"] is not None and "*" not in x["id_name"]
]
else:
div_con_entities = []
div_con_entities = list(set(div_con_entities))
logger.info(f"div_con_entities: {div_con_entities}")
with get_session_with_current_tenant() as db_session:
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
db_session
)
source_division = False
if div_con_entities:
for entity_type in double_grounded_entity_types:
if entity_type.grounded_source_name.lower() in div_con_entities[0].lower():
source_division = True
break
else:
raise ValueError("No div_con_entities found")
return DeepSearchFilterUpdate(
vespa_filter_results=vespa_filter_results,
div_con_entities=div_con_entities,
source_division=source_division,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="construct deep search filters",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,133 @@
import copy
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def process_individual_deep_search(
state: ResearchObjectInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
"""
LangGraph node to research an individual object using the constructed knowledge graph filters.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = state.broken_down_question
source_division = state.source_division
if not search_tool:
raise ValueError("search_tool is not provided")
object = state.entity.replace(":", ": ").lower()
if source_division:
extended_question = question
else:
extended_question = f"{question} in regards to {object}"
kg_entity_filters = copy.deepcopy(
state.vespa_filter_results.entity_filters + [state.entity]
)
kg_relationship_filters = copy.deepcopy(
state.vespa_filter_results.relationship_filters
)
logger.info("Research for object: " + object)
logger.info(f"kg_entity_filters: {kg_entity_filters}")
logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
# Add random wait between 1-3 seconds
# time.sleep(random.uniform(0, 3))
retrieved_docs = research(
question=extended_question,
kg_entities=kg_entity_filters,
kg_relationships=kg_relationship_filters,
search_tool=search_tool,
)
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num + 1) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# Build the prompt
datetime.now().strftime("%A, %Y-%m-%d")
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
question=extended_question,
document_text=document_texts,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=kg_object_source_research_prompt,
reserved_str="",
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Call the LLM to research the object
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
object_research_results = str(llm_response.content).replace("```json\n", "")
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ResearchObjectUpdate(
research_object_results=[
{
"object": object.replace(":", ": ").capitalize(),
"results": object_research_results,
}
],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="process individual deep search",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,125 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
"""
LangGraph node to construct deep search filters for an individual object (currently not wired into the graph builder).
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
entities = state.query_graph_entities
relationships = state.query_graph_relationships
simple_sql_query = state.sql_query
search_filter_construction_prompt = (
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
"---entity_type_descriptions---",
entities_types_str,
)
.replace(
"---entity_filters---",
"\n".join(entities),
)
.replace(
"---relationship_filters---",
"\n".join(relationships),
)
.replace(
"---sql_query---",
simple_sql_query or "(no SQL generated)",
)
.replace(
"---question---",
question,
)
)
msg = [
HumanMessage(
content=search_filter_construction_prompt,
)
]
llm = graph_config.tooling.primary_llm
# Call the LLM to construct the filters
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
vespa_filter_results = KGVespaFilterResults.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
if state.query_results:
div_con_entities = [
x["id_name"] for x in state.query_results if x["id_name"] is not None
]
else:
div_con_entities = []
return DeepSearchFilterUpdate(
vespa_filter_results=vespa_filter_results,
div_con_entities=div_con_entities,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="construct deep search filters",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,45 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def consoldidate_individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
"""
LangGraph node to consolidate the results of the individual deep searches.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
state.entities_types_str
research_object_results = state.research_object_results
consolidated_research_object_results_str = "\n".join(
[f"{x['object']}: {x['results']}" for x in research_object_results]
)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate simple sql",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,107 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _general_format(result: dict[str, Any]) -> str:
name = result.get("name")
entity_type_id_name = result.get("entity_type_id_name")
result.get("id_name")
if entity_type_id_name:
return f"{entity_type_id_name.capitalize()}: {name}"
else:
return f"{name}"
def _generate_reference_results(
individualized_query_results: list[dict[str, Any]]
) -> ReferenceResults:
"""
Generate reference results from the query results data string.
"""
citations: list[str] = []
general_entities = []
# get all entities that correspond to an Onyx document
document_ids: list[str] = [
cast(str, x.get("document_id"))
for x in individualized_query_results
if x.get("document_id")
]
with get_session_with_current_tenant() as session:
llm_doc_information_results = get_base_llm_doc_information(
session, document_ids
)
for llm_doc_information_result in llm_doc_information_results:
citations.append(llm_doc_information_result.center_chunk.semantic_identifier)
for result in individualized_query_results:
document_id: str | None = result.get("document_id")
if not document_id:
general_entities.append(_general_format(result))
return ReferenceResults(citations=citations, general_entities=general_entities)
def process_kg_only_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
"""
LangGraph node to format the knowledge-graph-only query results into answer data.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
query_results = state.query_results
individualized_query_results = state.individualized_query_results
query_results_list = []
if query_results:
for query_result in query_results:
query_results_list.append(str(query_result).replace(":", ": ").capitalize())
else:
raise ValueError("No query results were found")
query_results_data_str = "\n".join(query_results_list)
if individualized_query_results:
reference_results = _generate_reference_results(individualized_query_results)
else:
reference_results = None
return ResultsDataUpdate(
query_results_data_str=query_results_data_str,
individualized_query_results_data_str="",
reference_results=reference_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="kg query results data processing",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,157 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
"""
LangGraph node to generate and stream the final answer from the KG query results.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
introductory_answer = state.query_results_data_str
search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
consolidated_research_object_results_str = (
state.consolidated_research_object_results_str
)
output_format = state.output_format
if state.reference_results:
examples = (
state.reference_results.citations
or state.reference_results.general_entities
or []
)
research_results = "\n".join([f"- {example}" for example in examples])
elif consolidated_research_object_results_str:
research_results = consolidated_research_object_results_str
else:
research_results = ""
if research_results and introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_PROMPT.replace("---question---", question)
.replace("---introductory_answer---", introductory_answer)
.replace("---output_format---", str(output_format) if output_format else "")
.replace("---research_results---", research_results)
)
elif not research_results and introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace("---introductory_answer---", introductory_answer)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and not introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace("---research_results---", research_results)
)
else:
raise ValueError("No research results or introductory answer provided")
msg = [
HumanMessage(
content=output_format_prompt,
)
]
fast_llm = graph_config.tooling.fast_llm
dispatch_timings: list[float] = []
response: list[str] = []
def stream_answer() -> list[str]:
for message in fast_llm.stream(
prompt=msg,
timeout_override=30,
max_tokens=1000,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
# logger.debug(f"Answer piece: {content}")
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
response = run_with_timeout(
30,
stream_answer,
)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=0,
level_question_num=0,
)
write_custom_event("stream_finished", stop_event, writer)
dispatch_main_answer_stop_info(0, writer)
return MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)
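To make the placeholder substitution above concrete, here is a small sketch; the template text is invented, only the ---...--- markers and the replace-chain pattern come from the code:
# Invented template text; the real OUTPUT_FORMAT_PROMPT lives in onyx.prompts.kg_prompts.
template = (
    "Question: ---question---\n"
    "Overall answer:\n---introductory_answer---\n"
    "Requested format: ---output_format---\n"
    "Research results:\n---research_results---"
)
prompt = (
    template.replace("---question---", "Which accounts did we meet with last week?")
    .replace("---introductory_answer---", "Account: Acme Corp\nAccount: Globex")
    .replace("---output_format---", "KGAnswerFormat.LIST")
    .replace("---research_results---", "- Q3 Planning Notes")
)
print(prompt)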

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
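A hedged usage sketch of the helper above; the SearchTool instance is assumed to come from graph_config.tooling.search_tool inside a graph node, and the source/time values are purely illustrative:
from datetime import datetime, timedelta

from onyx.configs.constants import DocumentSource

# `search_tool` is assumed to be an already-configured SearchTool instance.
docs = research(
    question="What did Acme ask about pricing?",
    search_tool=search_tool,
    document_sources=[DocumentSource.SLACK],  # any DocumentSource member works here
    time_cutoff=datetime.now() - timedelta(days=90),
)
for doc in docs:
    print(doc.semantic_identifier)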

View File

@@ -0,0 +1,127 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class KGVespaFilterResults(BaseModel):
entity_filters: list[str]
relationship_filters: list[str]
class KGAnswerStrategy(Enum):
DEEP = "DEEP"
SIMPLE = "SIMPLE"
class KGAnswerFormat(Enum):
LIST = "LIST"
TEXT = "TEXT"
class YesNoEnum(str, Enum):
YES = "yes"
NO = "no"
class AnalysisUpdate(LoggerUpdate):
normalized_core_entities: list[str] = []
normalized_core_relationships: list[str] = []
query_graph_entities: list[str] = []
query_graph_relationships: list[str] = []
normalized_terms: list[str] = []
normalized_time_filter: str | None = None
strategy: KGAnswerStrategy | None = None
output_format: KGAnswerFormat | None = None
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
sql_query: str | None = None
query_results: list[Dict[Any, Any]] | None = None
individualized_sql_query: str | None = None
individualized_query_results: list[Dict[Any, Any]] | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
vespa_filter_results: KGVespaFilterResults | None = None
div_con_entities: list[str] | None = None
source_division: bool | None = None
class ResearchObjectOutput(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
entities_types_str: str = ""
entities: list[str] = []
relationships: list[str] = []
terms: list[str] = []
time_filter: str | None = None
class ResultsDataUpdate(LoggerUpdate):
query_results_data_str: str | None = None
individualized_query_results_data_str: str | None = None
reference_results: ReferenceResults | None = None
class ResearchObjectUpdate(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
ERTExtractionUpdate,
AnalysisUpdate,
SQLSimpleGenerationUpdate,
ResultsDataUpdate,
ResearchObjectOutput,
DeepSearchFilterUpdate,
ResearchObjectUpdate,
ConsolidatedResearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
class ResearchObjectInput(LoggerUpdate):
entity: str
broken_down_question: str
vespa_filter_results: KGVespaFilterResults
source_division: bool | None
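A side note on the Annotated[list[...], add] pattern used throughout these states: LangGraph treats the second annotation argument as a reducer for merging updates from parallel nodes. A minimal standalone illustration (the class name is invented):
from operator import add
from typing import Annotated

from pydantic import BaseModel


class LogState(BaseModel):
    log_messages: Annotated[list[str], add] = []


# LangGraph applies the annotated reducer when several nodes update the same key;
# conceptually the merge is plain list concatenation:
print(add(["node_a: done"], ["node_b: done"]))  # ['node_a: done', 'node_b: done']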

View File

@@ -12,12 +12,16 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dc_search_analysis.graph_builder import dc_graph_builder
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
@@ -86,7 +90,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput,
graph_input: BasicInput | MainInput | KBMainInput | DCMainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -100,7 +104,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput,
input: BasicInput | MainInput | KBMainInput | DCMainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -150,6 +154,15 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_kb_graph(
config: GraphConfig,
) -> AnswerStream:
graph = kb_graph_builder()
compiled_graph = graph.compile()
input = KBMainInput(log_messages=[])
return run_graph(compiled_graph, config, input)
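A hedged usage sketch of the new entry point; the GraphConfig construction is assumed to mirror the existing run_* helpers:
# `config` is assumed to be a fully populated GraphConfig, e.g. built the same way
# as for the other run_* helpers (see get_test_config).
for packet in run_kb_graph(config):
    print(type(packet).__name__)  # e.g. AgentAnswerPiece, StreamStopInfo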
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -27,6 +27,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_sub_question_answer_prompt(
question: str,

View File

@@ -159,3 +159,9 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
class ReferenceResults(BaseModel):
# citations: list[InferenceSection]
citations: list[str]
general_entities: list[str]

View File

@@ -11,6 +11,8 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -145,13 +147,21 @@ class Answer:
if self.graph_config.behavior.use_agentic_search:
run_langgraph = run_main_graph
elif self.graph_config.inputs.search_request.persona:
if (
self.graph_config.inputs.search_request.persona
and self.graph_config.inputs.search_request.persona.description.startswith(
"DivCon Beta Agent"
)
):
run_langgraph = run_dc_graph
elif self.graph_config.inputs.search_request.persona.name.startswith(
"KG Dev"
):
run_langgraph = run_kb_graph
else:
run_langgraph = run_basic_graph
else:
run_langgraph = run_basic_graph

View File

@@ -13,6 +13,7 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
@@ -100,6 +101,39 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
)
def inference_section_from_llm_doc(llm_doc: LlmDoc) -> InferenceSection:
# Create a center chunk first
center_chunk = InferenceChunk(
document_id=llm_doc.document_id,
chunk_id=0, # Default to 0 since LlmDoc doesn't have this info
content=llm_doc.content,
blurb=llm_doc.blurb,
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
metadata=llm_doc.metadata,
updated_at=llm_doc.updated_at,
source_links=llm_doc.source_links or {},
match_highlights=llm_doc.match_highlights or [],
section_continuation=False,
image_file_name=None,
title=None,
boost=1,
recency_bias=1.0,
score=None,
hidden=False,
doc_summary="",
chunk_context="",
)
# Create InferenceSection with the center chunk
# Since we don't have access to the original chunks, we'll use an empty list
return InferenceSection(
center_chunk=center_chunk,
chunks=[], # Original surrounding chunks are not available in LlmDoc
combined_content=llm_doc.content,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,

View File

@@ -100,6 +100,8 @@ from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_file_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -664,6 +666,17 @@ def stream_chat_message_objects(
llm: LLM
index_str = "danswer_chunk_text_embedding_3_small"
if new_msg_req.message == "ee":
kg_extraction(tenant_id, index_str)
raise Exception("Extractions done")
elif new_msg_req.message == "cc":
kg_clustering(tenant_id, index_str)
raise Exception("Clustering done")
try:
# Move these variables inside the try block
file_id_to_user_file = {}

View File

@@ -0,0 +1,12 @@
import json
import os
KG_OWN_EMAIL_DOMAINS: list[str] = json.loads(
os.environ.get("KG_OWN_EMAIL_DOMAINS", "[]")
) # must be list
KG_IGNORE_EMAIL_DOMAINS: list[str] = json.loads(
os.environ.get("KG_IGNORE_EMAIL_DOMAINS", "[]")
) # must be list
KG_OWN_COMPANY: str = os.environ.get("KG_OWN_COMPANY", "")
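These settings are read once at import time from the environment; as an illustration of the expected encoding (the domain values are invented, and in practice they would be set before the process starts):
import json
import os

# Illustrative values only; both list-valued settings must be JSON-encoded.
os.environ["KG_OWN_EMAIL_DOMAINS"] = json.dumps(["example.com"])
os.environ["KG_IGNORE_EMAIL_DOMAINS"] = json.dumps(["gmail.com"])
os.environ["KG_OWN_COMPANY"] = "Example Corp"

print(json.loads(os.environ["KG_OWN_EMAIL_DOMAINS"]))  # ['example.com']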

View File

@@ -113,6 +113,9 @@ class BaseFilters(BaseModel):
tags: list[Tag] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
class IndexFilters(BaseFilters):

View File

@@ -68,6 +68,7 @@ class SearchPipeline:
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request
self.user = user
self.llm = llm

View File

@@ -182,6 +182,9 @@ def retrieval_preprocessing(
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
kg_entities=preset_filters.kg_entities,
kg_relationships=preset_filters.kg_relationships,
kg_terms=preset_filters.kg_terms,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -349,6 +349,8 @@ def retrieve_chunks(
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
logger.info(f"RETRIEVAL CHUNKS query: {query}")
multilingual_expansion = get_multilingual_expansion(db_session)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:

View File

@@ -334,3 +334,21 @@ def mark_ccpair_with_indexing_trigger(
except Exception:
db_session.rollback()
raise
def get_unprocessed_connector_ids(db_session: Session) -> list[int]:
"""
Retrieves the IDs of all connectors that have knowledge graph extraction enabled.
Args:
db_session (Session): The database session to use
Returns:
list[int]: List of connector IDs with KG extraction enabled
"""
try:
stmt = select(Connector.id).where(Connector.kg_extraction_enabled)
result = db_session.execute(stmt)
return [row[0] for row in result.fetchall()]
except Exception as e:
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
raise e

View File

@@ -23,6 +23,8 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
@@ -377,6 +379,7 @@ def upsert_documents(
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
kg_processed=False,
)
)
for doc in seen_documents.values()
@@ -843,3 +846,213 @@ def fetch_chunk_count_for_document(
) -> int | None:
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()
def get_unprocessed_kg_documents_for_connector(
db_session: Session,
connector_id: int,
batch_size: int = 100,
) -> Generator[DbDocument, None, None]:
"""
Retrieves all documents associated with a connector that have not yet been processed
for knowledge graph extraction. Uses a generator pattern to handle large result sets.
Args:
db_session (Session): The database session to use
connector_id (int): The ID of the connector to check
batch_size (int): Number of documents to fetch per batch, defaults to 100
Yields:
DbDocument: Documents that haven't been KG processed, one at a time
"""
offset = 0
while True:
stmt = (
select(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
or_(
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
None
),
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
False
),
),
)
)
.distinct()
.limit(batch_size)
.offset(offset)
)
batch = list(db_session.scalars(stmt).all())
if not batch:
break
for document in batch:
yield document
offset += batch_size
def get_kg_processed_document_ids(db_session: Session) -> list[str]:
"""
Retrieves all document IDs where kg_processed is True.
Args:
db_session (Session): The database session to use
Returns:
list[str]: List of document IDs that have been KG processed
"""
stmt = select(DbDocument.id).where(DbDocument.kg_processed.is_(True))
return list(db_session.scalars(stmt).all())
def update_document_kg_info(
db_session: Session,
document_id: str,
kg_processed: bool,
kg_data: dict,
) -> None:
"""Updates the knowledge graph related information for a document.
Args:
db_session (Session): The database session to use
document_id (str): The ID of the document to update
kg_processed (bool): Whether the document has been processed for KG extraction
kg_data (dict): Dictionary containing KG data with 'entities', 'relationships', and 'terms' keys
Raises:
ValueError: If the document with the given ID is not found
"""
stmt = (
update(DbDocument)
.where(DbDocument.id == document_id)
.values(
kg_processed=kg_processed,
kg_data=kg_data,
)
)
db_session.execute(stmt)
def get_document_kg_info(
db_session: Session,
document_id: str,
) -> tuple[bool, dict] | None:
"""Retrieves the knowledge graph processing status and data for a document.
Args:
db_session (Session): The database session to use
document_id (str): The ID of the document to query
Returns:
Optional[Tuple[bool, dict]]: A tuple containing:
- bool: Whether the document has been KG processed
- dict: The KG data containing 'entities', 'relationships', and 'terms'
Returns None if the document is not found
"""
stmt = select(DbDocument.kg_processed, DbDocument.kg_data).where(
DbDocument.id == document_id
)
result = db_session.execute(stmt).one_or_none()
if result is None:
return None
return result.kg_processed, result.kg_data or {}
def get_all_kg_processed_documents_info(
db_session: Session,
) -> list[tuple[str, dict]]:
"""Retrieves the knowledge graph data for all documents that have been processed.
Args:
db_session (Session): The database session to use
Returns:
List[Tuple[str, dict]]: A list of tuples containing:
- str: The document ID
- dict: The KG data containing 'entities', 'relationships', and 'terms'
Only returns documents where kg_processed is True
"""
stmt = (
select(DbDocument.id, DbDocument.kg_data)
.where(DbDocument.kg_processed.is_(True))
.order_by(DbDocument.id)
)
results = db_session.execute(stmt).all()
return [(str(doc_id), kg_data or {}) for doc_id, kg_data in results]
def get_base_llm_doc_information(
db_session: Session, document_ids: list[str]
) -> list[InferenceSection]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
results = db_session.execute(stmt).all()
inference_sections = []
for doc in results:
bare_doc = doc[0]
inference_section = InferenceSection(
center_chunk=InferenceChunk(
document_id=bare_doc.id,
chunk_id=0,
source_type=DocumentSource.NOT_APPLICABLE,
semantic_identifier=bare_doc.semantic_id,
title=None,
boost=0,
recency_bias=0,
score=0,
hidden=False,
metadata={},
blurb="",
content="",
source_links=None,
image_file_name=None,
section_continuation=False,
match_highlights=[],
doc_summary="",
chunk_context="",
updated_at=None,
),
chunks=[],
combined_content="",
)
inference_sections.append(inference_section)
return inference_sections
def get_document_updated_at(
document_id: str,
db_session: Session,
) -> datetime | None:
"""Retrieves the doc_updated_at timestamp for a given document ID.
Args:
document_id (str): The ID of the document to query
db_session (Session): The database session to use
Returns:
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
"""
if len(document_id.split(":")) == 2:
document_id = document_id.split(":")[1]
elif len(document_id.split(":")) > 2:
raise ValueError(f"Invalid document ID: {document_id}")
else:
pass
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()
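The ID handling in get_document_updated_at accepts at most one source prefix; a standalone illustration (the example IDs are invented):
def _strip_source_prefix(document_id: str) -> str:
    # Mirrors the ID normalization in get_document_updated_at above.
    parts = document_id.split(":")
    if len(parts) == 2:
        return parts[1]
    if len(parts) > 2:
        raise ValueError(f"Invalid document ID: {document_id}")
    return document_id


print(_strip_source_prefix("jira:PROJ-123"))  # "PROJ-123"
print(_strip_source_prefix("plain-doc-id"))   # "plain-doc-id"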

backend/onyx/db/entities.py Normal file
View File

@@ -0,0 +1,271 @@
from datetime import datetime
from typing import List
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityType
def get_entity_types(
db_session: Session,
active: bool | None = True,
) -> list[KGEntityType]:
# Query the database for all distinct entity types
if active is None:
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
else:
return (
db_session.query(KGEntityType)
.filter(KGEntityType.active == active)
.order_by(KGEntityType.id_name)
.all()
)
def add_entity(
db_session: Session,
entity_type: str,
name: str,
document_id: str | None = None,
cluster_count: int = 0,
event_time: datetime | None = None,
) -> "KGEntity | None":
"""Add a new entity to the database.
Args:
db_session: SQLAlchemy session
entity_type: Type of the entity (must match an existing KGEntityType)
name: Name of the entity
document_id: Optional ID of the document the entity was extracted from
cluster_count: Number of clusters in which this entity has been found
event_time: Optional timestamp of the underlying event
Returns:
KGEntity: The created entity
"""
entity_type = entity_type.upper()
name = name.title()
id_name = f"{entity_type}:{name}"
# Create new entity
stmt = (
pg_insert(KGEntity)
.values(
id_name=id_name,
entity_type_id_name=entity_type,
document_id=document_id,
name=name,
cluster_count=cluster_count,
event_time=event_time,
)
.on_conflict_do_update(
index_elements=["id_name"],
set_=dict(
# Direct numeric addition without text()
cluster_count=KGEntity.cluster_count
+ literal_column("EXCLUDED.cluster_count"),
# Keep other fields updated as before
entity_type_id_name=entity_type,
document_id=document_id,
name=name,
event_time=event_time,
),
)
.returning(KGEntity)
)
result = db_session.execute(stmt).scalar()
return result
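For reference, the upsert above keys entities on an id_name built from the upper-cased type and title-cased name; a quick standalone illustration with an invented entity:
entity_type = "account".upper()  # "ACCOUNT"
name = "acme corp".title()       # "Acme Corp"
id_name = f"{entity_type}:{name}"
print(id_name)                   # "ACCOUNT:Acme Corp"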
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
"""
Check if a document_id exists in the kg_entities table and return its id_name if found.
Args:
db: SQLAlchemy database session
document_id: The document ID to search for
Returns:
The id_name of the matching KGEntity if found, None otherwise
"""
query = select(KGEntity).where(KGEntity.document_id == document_id)
result = db.execute(query).scalar()
return result
def get_ungrounded_entities(db_session: Session) -> List[KGEntity]:
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntity objects belonging to ungrounded entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntityType.grounding == "UE")
.all()
)
def get_grounded_entities(db_session: Session) -> List[KGEntity]:
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntity objects belonging to ungrounded entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntityType.grounding == "GE")
.all()
)
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
"""Get all entity types that have non-null ge_determine_instructions.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have ge_determine_instructions defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.ge_determine_instructions.isnot(None))
.all()
)
def get_entity_types_with_grounded_source_name(
db_session: Session,
) -> List[KGEntityType]:
"""Get all entity types that have non-null grounded_source_name.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have grounded_source_name defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.grounded_source_name.isnot(None))
.all()
)
def get_entity_types_with_grounding_signature(
db_session: Session,
) -> List[KGEntityType]:
"""Get all entity types that have non-null ge_grounding_signature.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have ge_grounding_signature defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.ge_grounding_signature.isnot(None))
.all()
)
def get_ge_entities_by_types(
db_session: Session, entity_types: List[str]
) -> List[KGEntity]:
"""Get all entities matching an entity_type.
Args:
db_session: SQLAlchemy session
entity_types: List of entity types to filter by
Returns:
List of KGEntity objects belonging to the specified entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntity.entity_type_id_name.in_(entity_types))
.filter(KGEntityType.grounding == "GE")
.all()
)
def delete_entities_by_id_names(db_session: Session, id_names: list[str]) -> int:
"""
Delete entities from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of entity id_names to delete
Returns:
Number of entities deleted
"""
deleted_count = (
db_session.query(KGEntity)
.filter(KGEntity.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def get_entities_for_types(
db_session: Session, entity_types: List[str]
) -> List[KGEntity]:
"""Get all entities that belong to the specified entity types.
Args:
db_session: SQLAlchemy session
entity_types: List of entity type id_names to filter by
Returns:
List of KGEntity objects belonging to the specified entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntity.entity_type_id_name.in_(entity_types))
.all()
)
def get_entity_type_by_grounded_source_name(
db_session: Session, grounded_source_name: str
) -> KGEntityType | None:
"""Get an entity type by its grounded_source_name and return it as a dictionary.
Args:
db_session: SQLAlchemy session
grounded_source_name: The grounded_source_name of the entity to retrieve
Returns:
Dictionary containing the entity's data with column names as keys,
or None if the entity is not found
"""
entity_type = (
db_session.query(KGEntityType)
.filter(KGEntityType.grounded_source_name == grounded_source_name)
.first()
)
if entity_type is None:
return None
return entity_type

View File

@@ -53,6 +53,7 @@ from onyx.db.enums import (
SyncType,
SyncStatus,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -586,6 +587,22 @@ class Document(Base):
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
# tables for the knowledge graph data
kg_processed: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this document has been processed for knowledge graph extraction",
)
kg_data: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Knowledge graph data extracted from this document",
)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
@@ -604,6 +621,304 @@ class Document(Base):
)
class KGEntityType(Base):
__tablename__ = "kg_entity_type"
# Primary identifier
id_name: Mapped[str] = mapped_column(
String, primary_key=True, nullable=False, index=True
)
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
grounding: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
)
grounded_source_name: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
)
ge_determine_instructions: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=True, default=None
)
ge_grounding_signature: Mapped[str] = mapped_column(
NullFilteredString, nullable=True, index=False, default=None
)
clustering: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Clustering information for this entity type",
)
classification_requirements: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Pre-extraction classification requirements and instructions",
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
extraction_sources: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
comment="Sources and methods used to extract this entity",
)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
class KGRelationshipType(Base):
__tablename__ = "kg_relationship_type"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
source_entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
target_entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
definition: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this relationship type represents a definition",
)
clustering: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Clustering information for this relationship type",
)
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# Relationships to EntityType
source_type: Mapped["KGEntityType"] = relationship(
"KGEntityType",
foreign_keys=[source_entity_type_id_name],
backref="source_relationship_type",
)
target_type: Mapped["KGEntityType"] = relationship(
"KGEntityType",
foreign_keys=[target_entity_type_id_name],
backref="target_relationship_type",
)
class KGEntity(Base):
__tablename__ = "kg_entity"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, index=True
)
# Basic entity information
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
document_id: Mapped[str | None] = mapped_column(
NullFilteredString, nullable=True, index=True
)
alternative_names: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Reference to KGEntityType
entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
# Relationship to KGEntityType
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
description: Mapped[str | None] = mapped_column(String, nullable=True)
keywords: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Access control
acl: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Boosts - using JSON for flexibility
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
event_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Time of the event being processed",
)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
__table_args__ = (
# Fixed column names in indexes
Index("ix_entity_type_acl", entity_type_id_name, acl),
Index("ix_entity_name_search", name, entity_type_id_name),
)
class KGRelationship(Base):
__tablename__ = "kg_relationship"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, index=True
)
# Source and target nodes (foreign keys to Entity table)
source_node: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
)
target_node: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
)
# Relationship type
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
# Add new relationship type reference
relationship_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_relationship_type.id_name"),
nullable=False,
index=True,
)
# Add the SQLAlchemy relationship property
relationship_type: Mapped["KGRelationshipType"] = relationship(
"KGRelationshipType", backref="relationship"
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# Relationships to Entity table
source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
__table_args__ = (
# Index for querying relationships by type
Index("ix_kg_relationship_type", type),
# Composite index for source/target queries
Index("ix_kg_relationship_nodes", source_node, target_node),
# Ensure unique relationships between nodes of a specific type
UniqueConstraint(
"source_node",
"target_node",
"type",
name="uq_kg_relationship_source_target_type",
),
)
class KGTerm(Base):
__tablename__ = "kg_term"
# Make id_term the primary key
id_term: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, nullable=False, index=True
)
# List of entity types this term applies to
entity_types: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
__table_args__ = (
# Index for searching terms with specific entity types
Index("ix_search_term_entities", entity_types),
# Index for term lookups
Index("ix_search_term_term", id_term),
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
@@ -691,6 +1006,14 @@ class Connector(Base):
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
kg_extraction_enabled: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this connector should extract knowledge graph entities",
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -1134,6 +1457,12 @@ class DocumentByConnectorCredentialPair(Base):
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
has_been_kg_processed: Mapped[bool | None] = mapped_column(
Boolean,
nullable=True,
comment="Whether this document has been processed for knowledge graph extraction",
)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector", passive_deletes=True
)

View File

@@ -0,0 +1,298 @@
from typing import List
from sqlalchemy import or_
from sqlalchemy.orm import Session
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipType
from onyx.kg.utils.formatting_utils import format_entity
from onyx.kg.utils.formatting_utils import format_relationship
from onyx.kg.utils.formatting_utils import generate_relationship_type
def add_relationship(
db_session: Session,
relationship_id_name: str,
cluster_count: int | None = None,
) -> "KGRelationship":
"""
Add a relationship between two entities to the database.
Args:
db_session: SQLAlchemy database session
relationship_id_name: Relationship ID in the format "SOURCE__relationship__TARGET"
cluster_count: Optional count of similar relationships clustered together
Returns:
The created KGRelationship object
Raises:
sqlalchemy.exc.IntegrityError: If the relationship already exists or entities don't exist
"""
# Generate a unique ID for the relationship
(
source_entity_id_name,
relationship_string,
target_entity_id_name,
) = relationship_id_name.split("__")
source_entity_id_name = format_entity(source_entity_id_name)
target_entity_id_name = format_entity(target_entity_id_name)
relationship_id_name = format_relationship(relationship_id_name)
relationship_type = generate_relationship_type(relationship_id_name)
# Create new relationship
relationship = KGRelationship(
id_name=relationship_id_name,
source_node=source_entity_id_name,
target_node=target_entity_id_name,
type=relationship_string.lower(),
relationship_type_id_name=relationship_type,
cluster_count=cluster_count,
)
db_session.add(relationship)
db_session.flush() # Flush to get any DB errors early
return relationship
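Relationship IDs follow a SOURCE__relationship__TARGET convention; a small standalone example of the split performed above (the concrete entities are invented):
relationship_id_name = "ACCOUNT:Acme Corp__renewed__PRODUCT:Widget"  # invented example

source_entity_id_name, relationship_string, target_entity_id_name = (
    relationship_id_name.split("__")
)
print(source_entity_id_name)  # "ACCOUNT:Acme Corp"
print(relationship_string)    # "renewed"
print(target_entity_id_name)  # "PRODUCT:Widget"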
def add_or_increment_relationship(
db_session: Session,
relationship_id_name: str,
) -> "KGRelationship":
"""
Add a relationship between two entities to the database if it doesn't exist,
or increment its cluster_count by 1 if it already exists.
Args:
db_session: SQLAlchemy database session
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
Returns:
The created or updated KGRelationship object
Raises:
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
"""
# Format the relationship_id_name
relationship_id_name = format_relationship(relationship_id_name)
# Check if the relationship already exists
existing_relationship = (
db_session.query(KGRelationship)
.filter(KGRelationship.id_name == relationship_id_name)
.first()
)
if existing_relationship:
# If it exists, increment the cluster_count
existing_relationship.cluster_count = (
existing_relationship.cluster_count or 0
) + 1
db_session.flush()
return existing_relationship
else:
# If it doesn't exist, add it with cluster_count=1
return add_relationship(db_session, relationship_id_name, cluster_count=1)
def add_relationship_type(
db_session: Session,
source_entity_type: str,
relationship_type: str,
target_entity_type: str,
definition: bool = False,
extraction_count: int = 0,
) -> "KGRelationshipType":
"""
Add a new relationship type to the database.
Args:
db_session: SQLAlchemy session
source_entity_type: Type of the source entity
relationship_type: Type of relationship
target_entity_type: Type of the target entity
definition: Whether this relationship type represents a definition (default False)
Returns:
The created KGRelationshipType object
Raises:
sqlalchemy.exc.IntegrityError: If the relationship type already exists
"""
id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"
# Create new relationship type
rel_type = KGRelationshipType(
id_name=id_name,
name=relationship_type,
source_entity_type_id_name=source_entity_type.upper(),
target_entity_type_id_name=target_entity_type.upper(),
definition=definition,
cluster_count=extraction_count,
type=relationship_type, # Using the relationship_type as the type
active=True, # Setting as active by default
)
db_session.add(rel_type)
db_session.flush() # Flush to get any DB errors early
return rel_type
def get_all_relationship_types(db_session: Session) -> list["KGRelationshipType"]:
"""
Retrieve all relationship types from the database.
Args:
db_session: SQLAlchemy database session
Returns:
List of KGRelationshipType objects
"""
return db_session.query(KGRelationshipType).all()
def get_all_relationships(db_session: Session) -> list["KGRelationship"]:
"""
Retrieve all relationships from the database.
Args:
db_session: SQLAlchemy database session
Returns:
List of KGRelationship objects
"""
return db_session.query(KGRelationship).all()
def delete_relationships_by_id_names(db_session: Session, id_names: list[str]) -> int:
"""
Delete relationships from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of relationship id_names to delete
Returns:
Number of relationships deleted
Raises:
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
"""
deleted_count = (
db_session.query(KGRelationship)
.filter(KGRelationship.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def delete_relationship_types_by_id_names(
db_session: Session, id_names: list[str]
) -> int:
"""
Delete relationship types from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of relationship type id_names to delete
Returns:
Number of relationship types deleted
Raises:
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
"""
deleted_count = (
db_session.query(KGRelationshipType)
.filter(KGRelationshipType.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def get_relationships_for_entity_type_pairs(
db_session: Session, entity_type_pairs: list[tuple[str, str]]
) -> list["KGRelationshipType"]:
"""
Get relationship types from the database based on a list of entity type pairs.
Args:
db_session: SQLAlchemy database session
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
Returns:
List of KGRelationshipType objects where source and target types match the provided pairs
"""
conditions = [
(
(KGRelationshipType.source_entity_type_id_name == source_type)
& (KGRelationshipType.target_entity_type_id_name == target_type)
)
for source_type, target_type in entity_type_pairs
]
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
def get_allowed_relationship_type_pairs(
db_session: Session, entities: list[str]
) -> list[str]:
"""
Get the allowed relationship pairs for the given entities.
Args:
db_session: SQLAlchemy database session
entities: List of entity type ID names to filter by
Returns:
List of id_names from KGRelationshipType where both source and target entity types
are in the provided entities list
"""
entity_types = list(set([entity.split(":")[0] for entity in entities]))
return [
row[0]
for row in (
db_session.query(KGRelationshipType.id_name)
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
.distinct()
.all()
)
]
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
"""Get all relationship ID names where the given entity is either the source or target node.
Args:
db_session: SQLAlchemy session
entity_id: ID of the entity to find relationships for
Returns:
List of relationship ID names where the entity is either source or target
"""
return [
row[0]
for row in (
db_session.query(KGRelationship.id_name)
.filter(
or_(
KGRelationship.source_node == entity_id,
KGRelationship.target_node == entity_id,
)
)
.all()
)
]

View File

@@ -98,7 +98,7 @@ class VespaDocumentFields:
understandable like this for now.
"""
# all other fields except these 4 and knowledge graph will always be left alone by the update request
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None

View File

@@ -85,6 +85,25 @@ schema DANSWER_CHUNK_NAME {
indexing: attribute
}
# Separate array fields for knowledge graph data
field kg_entities type weightedset<string> {
rank: filter
indexing: summary | attribute
attribute: fast-search
}
field kg_relationships type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field kg_terms type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute

View File

@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
# build the list of fields to retrieve
field_set_list = (
[] if not field_names else [f"{field_name}" for field_name in field_names]
)
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
if (
field_set_list
and filters.access_control_list
and acl_fieldset_entry not in field_set_list
):
field_set_list.append(acl_fieldset_entry)
field_set = ",".join(field_set_list) if field_set_list else None
if field_set_list:
field_set = f"{index_name}:" + ",".join(field_set_list)
else:
field_set = None
# build filters
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
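The resulting field_set string takes the form "<index_name>:field1,field2,..."; a quick standalone illustration with an invented index name:
index_name = "danswer_chunk_example_index"  # invented name
field_names = ["kg_relationships", "kg_entities", "kg_terms"]

field_set_list = [] if not field_names else [f"{field_name}" for field_name in field_names]
field_set = f"{index_name}:" + ",".join(field_set_list) if field_set_list else None
print(field_set)  # "danswer_chunk_example_index:kg_relationships,kg_entities,kg_terms"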

View File

@@ -17,6 +17,7 @@ from uuid import UUID
import httpx # type: ignore
import requests # type: ignore
from pydantic import BaseModel
from retry import retry
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
@@ -29,6 +30,7 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
@@ -100,6 +102,37 @@ class _VespaUpdateRequest:
update_request: dict[str, dict]
class KGVespaChunkUpdateRequest(BaseModel):
document_id: str
chunk_id: int
url: str
update_request: dict[str, dict]
class KGUChunkUpdateRequest(BaseModel):
"""
Update KG fields for a document
"""
document_id: str
chunk_id: int
core_entity: str
entities: set[str] | None = None
relationships: set[str] | None = None
terms: set[str] | None = None
class KGUDocumentUpdateRequest(BaseModel):
"""
Update KG fields for a document
"""
document_id: str
entities: set[str]
relationships: set[str]
terms: set[str]
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -504,6 +537,51 @@ class VespaIndex(DocumentIndex):
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
@classmethod
def _apply_kg_chunk_updates_batched(
cls,
updates: list[KGVespaChunkUpdateRequest],
httpx_client: httpx.Client,
batch_size: int = BATCH_SIZE,
) -> None:
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
def _kg_update_chunk(
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
) -> httpx.Response:
# logger.debug(
# f"Updating KG with request to {update.url} with body {update.update_request}"
# )
return http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx_client as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
executor.submit(
_kg_update_chunk,
update,
http_client,
): update.document_id
for update in update_batch
}
for future in concurrent.futures.as_completed(future_to_document_id):
res = future.result()
try:
res.raise_for_status()
except requests.HTTPError as e:
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
@@ -587,6 +665,89 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def kg_chunk_updates(
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
) -> None:
def _get_general_entity(specific_entity: str) -> str:
entity_type, entity_name = specific_entity.split(":")
if entity_type != "*":
return f"{entity_type}:*"
else:
return specific_entity
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
update_start = time.monotonic()
# Build the _VespaUpdateRequest objects
for kg_update_request in kg_update_requests:
kg_update_dict: dict[str, dict] = {"fields": {}}
implied_entities = set()
if kg_update_request.relationships is not None:
for kg_relationship in kg_update_request.relationships:
kg_relationship_split = kg_relationship.split("__")
if len(kg_relationship_split) == 3:
implied_entities.add(kg_relationship_split[0])
implied_entities.add(kg_relationship_split[2])
# Keep this for now in case we want to also add the general entities
# implied_entities.add(_get_general_entity(kg_relationship_split[0]))
# implied_entities.add(_get_general_entity(kg_relationship_split[2]))
kg_update_dict["fields"]["kg_relationships"] = {
"assign": {
kg_relationship: 1
for kg_relationship in kg_update_request.relationships
}
}
if kg_update_request.entities is not None or implied_entities:
if kg_update_request.entities is None:
kg_entities = implied_entities
else:
kg_entities = set(kg_update_request.entities)
kg_entities.update(implied_entities)
kg_update_dict["fields"]["kg_entities"] = {
"assign": {kg_entity: 1 for kg_entity in kg_entities}
}
if kg_update_request.terms is not None:
kg_update_dict["fields"]["kg_terms"] = {
"assign": {kg_term: 1 for kg_term in kg_update_request.terms}
}
if not kg_update_dict["fields"]:
logger.error("Update request received but nothing to update")
continue
doc_chunk_id = get_uuid_from_chunk_info(
document_id=kg_update_request.document_id,
chunk_id=kg_update_request.chunk_id,
tenant_id=tenant_id,
large_chunk_id=None,
)
processed_updates_requests.append(
KGVespaChunkUpdateRequest(
document_id=kg_update_request.document_id,
chunk_id=kg_update_request.chunk_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=kg_update_dict,
)
)
with self.httpx_client_context as httpx_client:
self._apply_kg_chunk_updates_batched(
processed_updates_requests, httpx_client
)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
)
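For reference, the per-chunk body assembled above ends up as a Vespa partial-update request using `assign` on the weightedset fields, roughly like this (entity and term values invented):
kg_update_dict = {
    "fields": {
        "kg_relationships": {
            "assign": {"ACCOUNT:Acme Corp__renewed__PRODUCT:Widget": 1}
        },
        "kg_entities": {
            "assign": {"ACCOUNT:Acme Corp": 1, "PRODUCT:Widget": 1}
        },
        "kg_terms": {"assign": {"pricing": 1}},
    }
}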
@retry(
tries=3,
delay=1,

View File

@@ -0,0 +1,82 @@
from pydantic import BaseModel
from retry import retry
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.utils.logger import setup_logger
# from backend.onyx.chat.process_message import get_inference_chunks
# from backend.onyx.document_index.vespa.index import VespaIndex
logger = setup_logger()
class KGChunkInfo(BaseModel):
kg_relationships: dict[str, int]
kg_entities: dict[str, int]
kg_terms: dict[str, int]
@retry(tries=3, delay=1, backoff=2)
def get_document_kg_info(
document_id: str,
index_name: str,
filters: IndexFilters | None = None,
) -> dict[int, KGChunkInfo]:
"""
Retrieve the kg_info attribute from a Vespa document by its document_id.
Args:
document_id: The unique identifier of the document.
index_name: The name of the Vespa index to query.
filters: Optional access control filters to apply.
Returns:
A mapping of chunk index to KGChunkInfo for the document's chunks (empty if none are found).
"""
# Use the existing visit API infrastructure
kg_doc_info: dict[int, KGChunkInfo] = {}
document_chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
field_names=["kg_relationships", "kg_entities", "kg_terms"],
get_large_chunks=False,
)
for chunk_id, document_chunk in enumerate(document_chunks):
kg_chunk_info = KGChunkInfo(
kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
kg_entities=document_chunk["fields"].get("kg_entities", {}),
kg_terms=document_chunk["fields"].get("kg_terms", {}),
)
kg_doc_info[chunk_id] = kg_chunk_info # TODO: check the chunk id is correct!
return kg_doc_info
@retry(tries=3, delay=1, backoff=2)
def update_kg_chunks_vespa_info(
kg_update_requests: list[KGUChunkUpdateRequest],
index_name: str,
tenant_id: str,
) -> None:
""" """
# Use the existing visit API infrastructure
vespa_index = VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=False,
multitenant=False,
httpx_client=None,
)
vespa_index.kg_chunk_updates(
kg_update_requests=kg_update_requests, tenant_id=tenant_id
)

View File

@@ -5,7 +5,6 @@ from datetime import timezone
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
@@ -67,6 +66,29 @@ def build_vespa_filters(
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
def _build_kg_filter(
kg_entities: list[str] | None,
kg_relationships: list[str] | None,
kg_terms: list[str] | None,
) -> str:
if not kg_entities and not kg_relationships and not kg_terms:
return ""
filter_parts = []
# Process each filter type using the same pattern
for filter_type, values in [
("kg_entities", kg_entities),
("kg_relationships", kg_relationships),
("kg_terms", kg_terms),
]:
if values:
filter_parts.append(
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
)
return f"({' and '.join(filter_parts)}) and "
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
@@ -106,6 +128,13 @@ def build_vespa_filters(
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
# KG filter
filter_str += _build_kg_filter(
kg_entities=filters.kg_entities,
kg_relationships=filters.kg_relationships,
kg_terms=filters.kg_terms,
)
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]
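# Illustrative sketch, not part of the original change: the YQL fragment that
# _build_kg_filter emits for some hypothetical inputs (spacing aside), which can be
# handy when debugging generated Vespa queries.
def _example_kg_filter_clause() -> str:
    # Returns roughly:
    # '((kg_entities contains "ACCOUNT:Acme") and (kg_terms contains "pricing") ) and '
    return _build_kg_filter(
        kg_entities=["ACCOUNT:Acme"],
        kg_relationships=None,
        kg_terms=["pricing"],
    )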

File diff suppressed because it is too large.


@@ -0,0 +1,229 @@
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
from thefuzz import process # type: ignore
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entities_for_types
from onyx.db.relationships import get_relationships_for_entity_type_pairs
from onyx.kg.models import NormalizedEntities
from onyx.kg.models import NormalizedRelationships
from onyx.kg.models import NormalizedTerms
from onyx.kg.utils.embeddings import encode_string_batch
def _get_existing_normalized_entities(raw_entities: List[str]) -> List[str]:
"""
Get existing normalized entities from the database.
"""
entity_types = list(set([entity.split(":")[0] for entity in raw_entities]))
with get_session_with_current_tenant() as db_session:
entities = get_entities_for_types(db_session, entity_types)
return [entity.id_name for entity in entities]
def _get_existing_normalized_relationships(
raw_relationships: List[str],
) -> Dict[str, Dict[str, List[str]]]:
"""
Get existing normalized relationships from the database.
"""
relationship_type_map: Dict[str, Dict[str, List[str]]] = defaultdict(
lambda: defaultdict(list)
)
relationship_pairs = list(
set(
[
(
relationship.split("__")[0].split(":")[0],
relationship.split("__")[2].split(":")[0],
)
for relationship in raw_relationships
]
)
)
with get_session_with_current_tenant() as db_session:
relationships = get_relationships_for_entity_type_pairs(
db_session, relationship_pairs
)
for relationship in relationships:
relationship_type_map[relationship.source_entity_type_id_name][
relationship.target_entity_type_id_name
].append(relationship.id_name)
return relationship_type_map
def normalize_entities(raw_entities: List[str]) -> NormalizedEntities:
"""
Match each entity against a list of normalized entities using fuzzy matching.
Returns the best matching normalized entity for each input entity.
Args:
entities: List of entity strings to normalize
Returns:
List of normalized entity strings
"""
# Assume this is your predefined list of normalized entities
norm_entities = _get_existing_normalized_entities(raw_entities)
normalized_results: List[str] = []
normalized_map: Dict[str, str | None] = {}
threshold = 80 # Adjust threshold as needed
for entity in raw_entities:
# Find the best match and its score from norm_entities (None if there are no candidates)
match_result = process.extractOne(entity, norm_entities) if norm_entities else None
if match_result is not None and match_result[1] >= threshold:
best_match = match_result[0]
normalized_results.append(best_match)
normalized_map[entity] = best_match
else:
# No sufficiently close match found; record the entity as unnormalized
normalized_map[entity] = None
return NormalizedEntities(
entities=normalized_results, entity_normalization_map=normalized_map
)
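# Illustrative sketch, not part of the original change: the fuzzy-matching step above
# in isolation, against a hypothetical candidate list instead of the database lookup.
def _example_fuzzy_entity_match() -> None:
    candidates = ["ACCOUNT:Acme Corp", "VENDOR:Globex"]  # hypothetical normalized entities
    match = process.extractOne("ACCOUNT:acme", candidates)
    if match is not None and match[1] >= 80:
        print(f"normalized to {match[0]} with score {match[1]}")
    else:
        print("no sufficiently close normalized entity found")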
def normalize_relationships(
raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
) -> NormalizedRelationships:
"""
Normalize relationships using entity mappings and relationship string matching.
Args:
relationships: List of relationships in format "source__relation__target"
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
Returns:
NormalizedRelationships containing normalized relationships and mapping
"""
# Placeholder for normalized relationship structure
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
normalized_rels: List[str] = []
normalization_map: Dict[str, str | None] = {}
for raw_rel in raw_relationships:
# 1. Split and normalize entities
try:
source, rel_string, target = raw_rel.split("__")
except ValueError:
raise ValueError(f"Invalid relationship format: {raw_rel}")
# Check if entities are in normalization map and not None
norm_source = entity_normalization_map.get(source)
norm_target = entity_normalization_map.get(target)
if norm_source is None or norm_target is None:
normalization_map[raw_rel] = None
continue
# 2. Find candidate normalized relationships
candidate_rels = []
norm_source_type = norm_source.split(":")[0]
norm_target_type = norm_target.split(":")[0]
if (
norm_source_type in nor_relationships
and norm_target_type in nor_relationships[norm_source_type]
):
candidate_rels = [
rel.split("__")[1]
for rel in nor_relationships[norm_source_type][norm_target_type]
]
if not candidate_rels:
normalization_map[raw_rel] = None
continue
# 3. Encode and find best match
strings_to_encode = [rel_string] + candidate_rels
vectors = encode_string_batch(strings_to_encode)
# Get raw relation vector and candidate vectors
raw_vector = vectors[0]
candidate_vectors = vectors[1:]
# Calculate dot products
dot_products = np.dot(candidate_vectors, raw_vector)
best_match_idx = np.argmax(dot_products)
# Create normalized relationship
norm_rel = f"{norm_source}__{candidate_rels[best_match_idx]}__{norm_target}"
normalized_rels.append(norm_rel)
normalization_map[raw_rel] = norm_rel
return NormalizedRelationships(
relationships=normalized_rels, relationship_normalization_map=normalization_map
)
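# Illustrative sketch, not part of the original change: the embedding-similarity step
# above in isolation. The candidate relationship strings are hypothetical.
def _example_pick_closest_relationship(rel_string: str) -> str:
    candidates = ["works_at", "is_customer_of", "partners_with"]  # hypothetical
    vectors = encode_string_batch([rel_string] + candidates)
    scores = np.dot(vectors[1:], vectors[0])  # dot product against each candidate
    return candidates[int(np.argmax(scores))]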
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
"""
Normalize terms using semantic similarity matching.
Args:
terms: List of terms to normalize
Returns:
NormalizedTerms containing normalized terms and mapping
"""
# # Placeholder for normalized terms - this would typically come from a predefined list
# normalized_term_list = [
# "algorithm",
# "database",
# "software",
# "programming",
# # ... other normalized terms ...
# ]
# normalized_terms: List[str] = []
# normalization_map: Dict[str, str | None] = {}
# if not raw_terms:
# return NormalizedTerms(terms=[], term_normalization_map={})
# # Encode all terms at once for efficiency
# strings_to_encode = raw_terms + normalized_term_list
# vectors = encode_string_batch(strings_to_encode)
# # Split vectors into query terms and candidate terms
# query_vectors = vectors[:len(raw_terms)]
# candidate_vectors = vectors[len(raw_terms):]
# # Calculate similarity for each term
# for i, term in enumerate(raw_terms):
# # Calculate dot products with all candidates
# similarities = np.dot(candidate_vectors, query_vectors[i])
# best_match_idx = np.argmax(similarities)
# best_match_score = similarities[best_match_idx]
# # Use a threshold to determine if the match is good enough
# if best_match_score > 0.7: # Adjust threshold as needed
# normalized_term = normalized_term_list[best_match_idx]
# normalized_terms.append(normalized_term)
# normalization_map[term] = normalized_term
# else:
# # If no good match found, keep original
# normalization_map[term] = None
# return NormalizedTerms(
# terms=normalized_terms,
# term_normalization_map=normalization_map
# )
return NormalizedTerms(
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
)


@@ -0,0 +1,191 @@
from onyx.configs.kg_configs import KG_IGNORE_EMAIL_DOMAINS
from onyx.configs.kg_configs import KG_OWN_COMPANY
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_kg_entity_by_document
from onyx.kg.context_preparations_extraction.models import ContextPreparation
from onyx.kg.context_preparations_extraction.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.prompts.kg_prompts import FIREFLIES_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT
def prepare_llm_content_fireflies(chunk: KGChunkFormat) -> ContextPreparation:
"""
Fireflies - prepare the content for the LLM.
"""
document_id = chunk.document_id
primary_owners = chunk.primary_owners
secondary_owners = chunk.secondary_owners
content = chunk.content
implied_entities = set()
implied_relationships = set()
with get_session_with_current_tenant() as db_session:
core_document = get_kg_entity_by_document(db_session, document_id)
if core_document:
core_document_id_name = core_document.id_name
else:
core_document_id_name = f"FIREFLIES:{document_id}"
# Do we need this here?
implied_entities.add(f"VENDOR:{KG_OWN_COMPANY}")
implied_entities.add(f"{core_document_id_name}")
implied_entities.add("FIREFLIES:*")
implied_relationships.add(
f"VENDOR:{KG_OWN_COMPANY}__relates_neutrally_to__{core_document_id_name}"
)
company_participant_emails = set()
account_participant_emails = set()
for owner in primary_owners + secondary_owners:
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
kg_owner = kg_email_processing(owner)
if any(
domain.lower() in kg_owner.company.lower()
for domain in KG_IGNORE_EMAIL_DOMAINS
):
continue
if kg_owner.employee:
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
if kg_owner.name not in implied_entities:
generalized_target_entity = list(
generalize_entities([core_document_id_name])
)[0]
implied_entities.add(f"EMPLOYEE:{kg_owner.name}")
implied_relationships.add(
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{generalized_target_entity}"
)
implied_relationships.add(
f"EMPLOYEE:*__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"EMPLOYEE:*__relates_neutrally_to__{generalized_target_entity}"
)
if kg_owner.company not in implied_entities:
implied_entities.add(f"VENDOR:{kg_owner.company}")
implied_relationships.add(
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
)
else:
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
if kg_owner.company not in implied_entities:
implied_entities.add(f"ACCOUNT:{kg_owner.company}")
implied_entities.add("ACCOUNT:*")
implied_relationships.add(
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"ACCOUNT:*__relates_neutrally_to__{core_document_id_name}"
)
generalized_target_entity = list(
generalize_entities([core_document_id_name])
)[0]
implied_relationships.add(
f"ACCOUNT:*__relates_neutrally_to__{generalized_target_entity}"
)
implied_relationships.add(
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
)
participant_string = "\n - " + "\n - ".join(company_participant_emails)
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
llm_context = FIREFLIES_CHUNK_PREPROCESSING_PROMPT.format(
participant_string=participant_string,
account_participant_string=account_participant_string,
content=content,
)
return ContextPreparation(
llm_context=llm_context,
core_entity=core_document_id_name,
implied_entities=list(implied_entities),
implied_relationships=list(implied_relationships),
implied_terms=[],
)
def prepare_llm_document_content_fireflies(
document_classification_content: KGClassificationContent,
category_list: str,
category_definition_string: str,
) -> KGDocumentClassificationPrompt:
"""
Fireflies - prepare prompt for the LLM classification.
"""
prompt = FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT.format(
beginning_of_call_content=document_classification_content.classification_content,
category_list=category_list,
category_options=category_definition_string,
)
return KGDocumentClassificationPrompt(
llm_prompt=prompt,
)
def get_classification_content_from_fireflies_chunks(
first_num_classification_chunks: list[dict],
) -> str:
"""
Creates a KGClassificationContent object from a list of Fireflies chunks.
"""
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
primary_owners = first_num_classification_chunks[0]["fields"]["primary_owners"]
secondary_owners = first_num_classification_chunks[0]["fields"]["secondary_owners"]
company_participant_emails = set()
account_participant_emails = set()
for owner in primary_owners + secondary_owners:
kg_owner = kg_email_processing(owner)
if any(
domain.lower() in kg_owner.company.lower()
for domain in KG_IGNORE_EMAIL_DOMAINS
):
continue
if kg_owner.employee:
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
else:
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
participant_string = "\n - " + "\n - ".join(company_participant_emails)
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
title_string = first_num_classification_chunks[0]["fields"]["title"]
content_string = "\n".join(
[
chunk_content["fields"]["content"]
for chunk_content in first_num_classification_chunks
]
)
classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
return classification_content


@@ -0,0 +1,21 @@
from pydantic import BaseModel
class ContextPreparation(BaseModel):
"""
Context preparation format for the LLM KG extraction.
"""
llm_context: str
core_entity: str
implied_entities: list[str]
implied_relationships: list[str]
implied_terms: list[str]
class KGDocumentClassificationPrompt(BaseModel):
"""
Document classification prompt format for the LLM KG extraction.
"""
llm_prompt: str | None


@@ -0,0 +1,947 @@
import json
from collections import defaultdict
from collections.abc import Callable
from typing import cast
from typing import Dict
from langchain_core.messages import HumanMessage
from onyx.db.connector import get_unprocessed_connector_ids
from onyx.db.document import get_document_updated_at
from onyx.db.document import get_unprocessed_kg_documents_for_connector
from onyx.db.document import update_document_kg_info
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import add_entity
from onyx.db.entities import get_entity_types
from onyx.db.relationships import add_or_increment_relationship
from onyx.db.relationships import add_relationship
from onyx.db.relationships import add_relationship_type
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import KGUDocumentUpdateRequest
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
from onyx.kg.models import ConnectorExtractionStats
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGBatchExtractionStats
from onyx.kg.models import KGChunkExtraction
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGChunkId
from onyx.kg.models import KGClassificationContent
from onyx.kg.models import KGClassificationDecisions
from onyx.kg.models import KGClassificationInstructionStrings
from onyx.kg.utils.chunk_preprocessing import prepare_llm_content
from onyx.kg.utils.chunk_preprocessing import prepare_llm_document_content
from onyx.kg.utils.formatting_utils import aggregate_kg_extractions
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import generalize_relationships
from onyx.kg.vespa.vespa_interactions import get_document_chunks_for_kg_processing
from onyx.kg.vespa.vespa_interactions import (
get_document_classification_content_for_kg_processing,
)
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import message_to_string
from onyx.prompts.kg_prompts import MASTER_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def _get_classification_instructions() -> Dict[str, KGClassificationInstructionStrings]:
"""
Prepare the classification instructions for each grounded source type.
"""
classification_instructions_dict: Dict[str, KGClassificationInstructionStrings] = {}
with get_session_with_current_tenant() as db_session:
entity_types = get_entity_types(db_session, active=None)
for entity_type in entity_types:
grounded_source_name = entity_type.grounded_source_name
if grounded_source_name is None:
continue
classification_class_definitions = entity_type.classification_requirements
classification_options = ", ".join(classification_class_definitions.keys())
classification_instructions_dict[
grounded_source_name
] = KGClassificationInstructionStrings(
classification_options=classification_options,
classification_class_definitions=classification_class_definitions,
)
return classification_instructions_dict
def get_entity_types_str(active: bool | None = None) -> str:
"""
Get the entity types from the KGChunkExtraction model.
"""
with get_session_with_current_tenant() as db_session:
active_entity_types = get_entity_types(db_session, active)
entity_types_list = []
for entity_type in active_entity_types:
if entity_type.description:
entity_description = "\n - Description: " + entity_type.description
else:
entity_description = ""
if entity_type.ge_determine_instructions:
allowed_options = "\n - Allowed Options: " + ", ".join(
entity_type.ge_determine_instructions
)
else:
allowed_options = ""
entity_types_list.append(
entity_type.id_name + entity_description + allowed_options
)
return "\n".join(entity_types_list)
def get_relationship_types_str(active: bool | None = None) -> str:
"""
Get the relationship types from the database.
Args:
active: Filter by active status (True, False, or None for all)
Returns:
A string with all relationship types formatted as "source_type__relationship_type__target_type"
"""
from onyx.db.relationships import get_all_relationship_types
with get_session_with_current_tenant() as db_session:
relationship_types = get_all_relationship_types(db_session)
# Filter by active status if specified
if active is not None:
relationship_types = [
rt for rt in relationship_types if rt.active == active
]
relationship_types_list = []
for rel_type in relationship_types:
# Format as "source_type__relationship_type__target_type"
formatted_type = f"{rel_type.source_entity_type_id_name}__{rel_type.type}__{rel_type.target_entity_type_id_name}"
relationship_types_list.append(formatted_type)
return "\n".join(relationship_types_list)
def kg_extraction_initialization(tenant_id: str, num_chunks: int = 1000) -> None:
"""
This extraction will create a random sample of chunks to process in order to perform
clustering and topic modeling.
"""
logger.info(f"Starting kg extraction for tenant {tenant_id}")
def kg_extraction(
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
) -> list[ConnectorExtractionStats]:
"""
This extraction processes all chunks that have not yet been kg-processed.
Approach:
- Get all unprocessed connectors
- For each connector:
- Get all unprocessed documents
- Classify each document to select proper ones
- For each document:
- Get all chunks
- For each chunk:
- Extract entities, relationships, and terms
- ensure that the generalized versions of each entity and relationship are extracted as well
- Aggregate results as needed
- Update Vespa and postgres
"""
logger.info(f"Starting kg extraction for tenant {tenant_id}")
with get_session_with_current_tenant() as db_session:
connector_ids = get_unprocessed_connector_ids(db_session)
connector_extraction_stats: list[ConnectorExtractionStats] = []
document_kg_updates: Dict[str, KGUDocumentUpdateRequest] = {}
processing_chunks: list[KGChunkFormat] = []
carryover_chunks: list[KGChunkFormat] = []
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions] = []
document_classification_instructions = _get_classification_instructions()
for connector_id in connector_ids:
connector_failed_chunk_extractions: list[KGChunkId] = []
connector_succeeded_chunk_extractions: list[KGChunkId] = []
connector_aggregated_kg_extractions: KGAggregatedExtractions = (
KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
)
with get_session_with_current_tenant() as db_session:
unprocessed_documents = get_unprocessed_kg_documents_for_connector(
db_session,
connector_id,
)
# TODO: restricted for testing only
unprocessed_documents_list = list(unprocessed_documents)[:6]
document_classification_content_list = (
get_document_classification_content_for_kg_processing(
[
unprocessed_document.id
for unprocessed_document in unprocessed_documents_list
],
index_name,
batch_size=processing_chunk_batch_size,
)
)
classification_outcomes: list[tuple[bool, KGClassificationDecisions]] = []
for document_classification_content in document_classification_content_list:
classification_outcomes.extend(
_kg_document_classification(
document_classification_content,
document_classification_instructions,
)
)
documents_to_process = []
for document_to_process, document_classification_outcome in zip(
unprocessed_documents_list, classification_outcomes
):
if (
document_classification_outcome[0]
and document_classification_outcome[1].classification_decision
):
documents_to_process.append(document_to_process)
for document_to_process in documents_to_process:
formatted_chunk_batches = get_document_chunks_for_kg_processing(
document_to_process.id,
index_name,
batch_size=processing_chunk_batch_size,
)
formatted_chunk_batches_list = list(formatted_chunk_batches)
for formatted_chunk_batch in formatted_chunk_batches_list:
processing_chunks.extend(formatted_chunk_batch)
if len(processing_chunks) >= processing_chunk_batch_size:
carryover_chunks.extend(
processing_chunks[processing_chunk_batch_size:]
)
processing_chunks = processing_chunks[:processing_chunk_batch_size]
chunk_processing_batch_results = _kg_chunk_batch_extraction(
processing_chunks, index_name, tenant_id
)
# Consider removing the stats expressions here and rather write to the db(?)
connector_failed_chunk_extractions.extend(
chunk_processing_batch_results.failed
)
connector_succeeded_chunk_extractions.extend(
chunk_processing_batch_results.succeeded
)
aggregated_batch_extractions = (
chunk_processing_batch_results.aggregated_kg_extractions
)
# Update grounded_entities_document_ids (replace values)
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
aggregated_batch_extractions.grounded_entities_document_ids
)
# Add to entity counts instead of replacing
for entity, count in aggregated_batch_extractions.entities.items():
connector_aggregated_kg_extractions.entities[entity] += count
# Add to relationship counts instead of replacing
for (
relationship,
count,
) in aggregated_batch_extractions.relationships.items():
connector_aggregated_kg_extractions.relationships[
relationship
] += count
# Add to term counts instead of replacing
for term, count in aggregated_batch_extractions.terms.items():
connector_aggregated_kg_extractions.terms[term] += count
connector_extraction_stats.append(
ConnectorExtractionStats(
connector_id=connector_id,
num_failed=len(connector_failed_chunk_extractions),
num_succeeded=len(connector_succeeded_chunk_extractions),
num_processed=len(processing_chunks),
)
)
processing_chunks = carryover_chunks.copy()
carryover_chunks = []
# processes remaining chunks
chunk_processing_batch_results = _kg_chunk_batch_extraction(
processing_chunks, index_name, tenant_id
)
# Consider removing the stats expressions here and rather write to the db(?)
connector_failed_chunk_extractions.extend(chunk_processing_batch_results.failed)
connector_succeeded_chunk_extractions.extend(
chunk_processing_batch_results.succeeded
)
aggregated_batch_extractions = (
chunk_processing_batch_results.aggregated_kg_extractions
)
# Update grounded_entities_document_ids (replace values)
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
aggregated_batch_extractions.grounded_entities_document_ids
)
# Add to entity counts instead of replacing
for entity, count in aggregated_batch_extractions.entities.items():
connector_aggregated_kg_extractions.entities[entity] += count
# Add to relationship counts instead of replacing
for relationship, count in aggregated_batch_extractions.relationships.items():
connector_aggregated_kg_extractions.relationships[relationship] += count
# Add to term counts instead of replacing
for term, count in aggregated_batch_extractions.terms.items():
connector_aggregated_kg_extractions.terms[term] += count
connector_extraction_stats.append(
ConnectorExtractionStats(
connector_id=connector_id,
num_failed=len(connector_failed_chunk_extractions),
num_succeeded=len(connector_succeeded_chunk_extractions),
num_processed=len(processing_chunks),
)
)
processing_chunks = []
carryover_chunks = []
connector_aggregated_kg_extractions_list.append(
connector_aggregated_kg_extractions
)
# aggregate document updates
for (
processed_document
) in (
unprocessed_documents_list
): # This will need to change if we do not materialize docs
document_kg_updates[processed_document.id] = KGUDocumentUpdateRequest(
document_id=processed_document.id,
entities=set(),
relationships=set(),
terms=set(),
)
updated_chunk_batches = get_document_chunks_for_kg_processing(
processed_document.id,
index_name,
batch_size=processing_chunk_batch_size,
)
for updated_chunk_batch in updated_chunk_batches:
for updated_chunk in updated_chunk_batch:
chunk_entities = updated_chunk.entities
chunk_relationships = updated_chunk.relationships
chunk_terms = updated_chunk.terms
document_kg_updates[processed_document.id].entities.update(
chunk_entities
)
document_kg_updates[processed_document.id].relationships.update(
chunk_relationships
)
document_kg_updates[processed_document.id].terms.update(chunk_terms)
aggregated_kg_extractions = aggregate_kg_extractions(
connector_aggregated_kg_extractions_list
)
with get_session_with_current_tenant() as db_session:
tracked_entity_types = [
x.id_name for x in get_entity_types(db_session, active=None)
]
# Populate the KG database with the extracted entities, relationships, and terms
for (
entity,
extraction_count,
) in aggregated_kg_extractions.entities.items():
if len(entity.split(":")) != 2:
logger.error(
f"Invalid entity {entity} in aggregated_kg_extractions.entities"
)
continue
entity_type, entity_name = entity.split(":")
entity_type = entity_type.upper()
entity_name = entity_name.capitalize()
if entity_type not in tracked_entity_types:
continue
try:
with get_session_with_current_tenant() as db_session:
if (
entity
not in aggregated_kg_extractions.grounded_entities_document_ids
):
add_entity(
db_session=db_session,
entity_type=entity_type,
name=entity_name,
cluster_count=extraction_count,
)
else:
event_time = get_document_updated_at(
entity,
db_session,
)
add_entity(
db_session=db_session,
entity_type=entity_type,
name=entity_name,
cluster_count=extraction_count,
document_id=aggregated_kg_extractions.grounded_entities_document_ids[
entity
],
event_time=event_time,
)
db_session.commit()
except Exception as e:
logger.error(f"Error adding entity {entity} to the database: {e}")
relationship_type_counter: dict[str, int] = defaultdict(int)
for (
relationship,
extraction_count,
) in aggregated_kg_extractions.relationships.items():
relationship_split = relationship.split("__")
source_entity = relationship_split[0]
relationship_type = " ".join(relationship_split[1:-1]).replace("__", "_")
target_entity = relationship_split[-1]
source_entity_type = source_entity.split(":")[0]
target_entity_type = target_entity.split(":")[0]
if (
source_entity_type not in tracked_entity_types
or target_entity_type not in tracked_entity_types
):
continue
source_entity_general = f"{source_entity_type.upper()}"
target_entity_general = f"{target_entity_type.upper()}"
relationship_type_id_name = (
f"{source_entity_general}__{relationship_type.lower()}__"
f"{target_entity_general}"
)
relationship_type_counter[relationship_type_id_name] += extraction_count
for (
relationship_type_id_name,
extraction_count,
) in relationship_type_counter.items():
(
source_entity_type,
relationship_type,
target_entity_type,
) = relationship_type_id_name.split("__")
if (
source_entity_type not in tracked_entity_types
or target_entity_type not in tracked_entity_types
):
continue
try:
with get_session_with_current_tenant() as db_session:
try:
add_relationship_type(
db_session=db_session,
source_entity_type=source_entity_type.upper(),
relationship_type=relationship_type,
target_entity_type=target_entity_type.upper(),
definition=False,
extraction_count=extraction_count,
)
db_session.commit()
except Exception as e:
logger.error(
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
)
except Exception as e:
logger.error(
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
)
for (
relationship,
extraction_count,
) in aggregated_kg_extractions.relationships.items():
relationship_split = relationship.split("__")
source_entity = relationship_split[0]
relationship_type = (
" ".join(relationship_split[1:-1]).replace("__", " ").replace("_", " ")
)
target_entity = relationship_split[-1]
source_entity_type = source_entity.split(":")[0]
target_entity_type = target_entity.split(":")[0]
try:
with get_session_with_current_tenant() as db_session:
add_relationship(db_session, relationship, extraction_count)
db_session.commit()
except Exception as e:
logger.error(
f"Error adding relationship {relationship} to the database: {e}"
)
with get_session_with_current_tenant() as db_session:
add_or_increment_relationship(db_session, relationship)
db_session.commit()
# Populate the Documents table with the kg information for the documents
for document_id, document_kg_update in document_kg_updates.items():
with get_session_with_current_tenant() as db_session:
update_document_kg_info(
db_session,
document_id,
kg_processed=True,
kg_data={
"entities": list(document_kg_update.entities),
"relationships": list(document_kg_update.relationships),
"terms": list(document_kg_update.terms),
},
)
db_session.commit()
return connector_extraction_stats
def _kg_chunk_batch_extraction(
chunks: list[KGChunkFormat],
index_name: str,
tenant_id: str,
) -> KGBatchExtractionStats:
_, fast_llm = get_default_llms()
succeeded_chunk_id: list[KGChunkId] = []
failed_chunk_id: list[KGChunkId] = []
succeeded_chunk_extraction: list[KGChunkExtraction] = []
preformatted_prompt = MASTER_EXTRACTION_PROMPT.format(
entity_types=get_entity_types_str(active=True)
)
def process_single_chunk(
chunk: KGChunkFormat, preformatted_prompt: str
) -> tuple[bool, KGUChunkUpdateRequest]:
"""Process a single chunk and return success status and chunk ID."""
# For now, we're just processing the content
# TODO: Implement actual prompt application logic
llm_preprocessing = prepare_llm_content(chunk)
formatted_prompt = preformatted_prompt.replace(
"---content---", llm_preprocessing.llm_context
)
msg = [
HumanMessage(
content=formatted_prompt,
)
]
try:
logger.info(
f"LLM Extraction from chunk {chunk.chunk_id} from doc {chunk.document_id}"
)
raw_extraction_result = fast_llm.invoke(msg)
extraction_result = message_to_string(raw_extraction_result)
try:
cleaned_result = (
extraction_result.replace("```json", "").replace("```", "").strip()
)
parsed_result = json.loads(cleaned_result)
extracted_entities = parsed_result.get("entities", [])
extracted_relationships = parsed_result.get("relationships", [])
extracted_terms = parsed_result.get("terms", [])
implied_extracted_relationships = [
llm_preprocessing.core_entity
+ "__"
+ "relates_neutrally_to"
+ "__"
+ entity
for entity in extracted_entities
]
all_entities = set(
list(extracted_entities)
+ list(llm_preprocessing.implied_entities)
+ list(
generalize_entities(
extracted_entities + llm_preprocessing.implied_entities
)
)
)
logger.info(f"All entities: {all_entities}")
all_relationships = (
extracted_relationships
+ llm_preprocessing.implied_relationships
+ implied_extracted_relationships
)
all_relationships = set(
list(all_relationships)
+ list(generalize_relationships(all_relationships))
)
kg_updates = [
KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity=llm_preprocessing.core_entity,
entities=all_entities,
relationships=all_relationships,
terms=set(extracted_terms),
),
]
update_kg_chunks_vespa_info(
kg_update_requests=kg_updates,
index_name=index_name,
tenant_id=tenant_id,
)
logger.info(
f"KG updated: {chunk.chunk_id} from doc {chunk.document_id}"
)
return True, kg_updates[0] # only single chunk
except json.JSONDecodeError as e:
logger.error(
f"Invalid JSON format for extraction of chunk {chunk.chunk_id} \
from doc {chunk.document_id}: {str(e)}"
)
logger.error(f"Raw output: {extraction_result}")
return False, KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity=llm_preprocessing.core_entity,
entities=set(),
relationships=set(),
terms=set(),
)
except Exception as e:
logger.error(
f"Failed to process chunk {chunk.chunk_id} from doc {chunk.document_id}: {str(e)}"
)
return False, KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity="",
entities=set(),
relationships=set(),
terms=set(),
)
# Assume for prototype: use_threads = True. TODO: Make thread safe!
functions_with_args: list[tuple[Callable, tuple]] = [
(process_single_chunk, (chunk, preformatted_prompt)) for chunk in chunks
]
logger.debug("Running KG extraction on chunks in parallel")
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
# Sort results into succeeded and failed
for success, chunk_results in results:
if success:
succeeded_chunk_id.append(
KGChunkId(
document_id=chunk_results.document_id,
chunk_id=chunk_results.chunk_id,
)
)
succeeded_chunk_extraction.append(chunk_results)
else:
failed_chunk_id.append(
KGChunkId(
document_id=chunk_results.document_id,
chunk_id=chunk_results.chunk_id,
)
)
# Collect data for postgres later on
aggregated_kg_extractions = KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
for chunk_result in succeeded_chunk_extraction:
aggregated_kg_extractions.grounded_entities_document_ids[
chunk_result.core_entity
] = chunk_result.document_id
mentioned_chunk_entities: set[str] = set()
for relationship in chunk_result.relationships:
relationship_split = relationship.split("__")
if len(relationship_split) == 3:
if relationship_split[0] not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[relationship_split[0]] += 1
mentioned_chunk_entities.add(relationship_split[0])
if relationship_split[2] not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[relationship_split[2]] += 1
mentioned_chunk_entities.add(relationship_split[2])
aggregated_kg_extractions.relationships[relationship] += 1
for kg_entity in chunk_result.entities:
if kg_entity not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[kg_entity] += 1
mentioned_chunk_entities.add(kg_entity)
for kg_term in chunk_result.terms:
aggregated_kg_extractions.terms[kg_term] += 1
return KGBatchExtractionStats(
connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
succeeded=succeeded_chunk_id,
failed=failed_chunk_id,
aggregated_kg_extractions=aggregated_kg_extractions,
)
def _kg_document_classification(
document_classification_content_list: list[KGClassificationContent],
classification_instructions: dict[str, KGClassificationInstructionStrings],
) -> list[tuple[bool, KGClassificationDecisions]]:
primary_llm, fast_llm = get_default_llms()
def classify_single_document(
document_classification_content: KGClassificationContent,
classification_instructions: dict[str, KGClassificationInstructionStrings],
) -> tuple[bool, KGClassificationDecisions]:
"""Classify a single document whether it should be kg-processed or not"""
source = document_classification_content.source_type
document_id = document_classification_content.document_id
if source not in classification_instructions:
logger.info(
f"Source {source} did not have kg classification instructions. No content analysis."
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
classification_prompt = prepare_llm_document_content(
document_classification_content,
category_list=classification_instructions[source].classification_options,
category_definitions=classification_instructions[
source
].classification_class_definitions,
)
if classification_prompt.llm_prompt is None:
logger.info(
f"Source {source} did not have kg document classification instructions. No content analysis."
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
msg = [
HumanMessage(
content=classification_prompt.llm_prompt,
)
]
try:
logger.info(
f"LLM Classification from document {document_classification_content.document_id}"
)
raw_classification_result = primary_llm.invoke(msg)
classification_result = (
message_to_string(raw_classification_result)
.replace("```json", "")
.replace("```", "")
.strip()
)
classification_class = classification_result.split("CATEGORY:")[1].strip()
if (
classification_class
in classification_instructions[source].classification_class_definitions
):
extraction_decision = cast(
bool,
classification_instructions[
source
].classification_class_definitions[classification_class][
"extraction"
],
)
else:
extraction_decision = False
return True, KGClassificationDecisions(
document_id=document_id,
classification_decision=extraction_decision,
classification_class=classification_class,
)
except Exception as e:
logger.error(
f"Failed to classify document {document_classification_content.document_id}: {str(e)}"
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
# Assume for prototype: use_threads = True. TODO: Make thread safe!
functions_with_args: list[tuple[Callable, tuple]] = [
(
classify_single_document,
(document_classification_content, classification_instructions),
)
for document_classification_content in document_classification_content_list
]
logger.debug("Running KG classification on documents in parallel")
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
return results
# logger.debug("Running KG extraction on chunks in parallel")
# results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
# # Sort results into succeeded and failed
# for success, chunk_results in results:
# if success:
# succeeded_chunk_id.append(
# KGChunkId(
# document_id=chunk_results.document_id,
# chunk_id=chunk_results.chunk_id,
# )
# )
# succeeded_chunk_extraction.append(chunk_results)
# else:
# failed_chunk_id.append(
# KGChunkId(
# document_id=chunk_results.document_id,
# chunk_id=chunk_results.chunk_id,
# )
# )
# # Collect data for postgres later on
# aggregated_kg_extractions = KGAggregatedExtractions(
# grounded_entities_document_ids=defaultdict(str),
# entities=defaultdict(int),
# relationships=defaultdict(int),
# terms=defaultdict(int),
# )
# for chunk_result in succeeded_chunk_extraction:
# aggregated_kg_extractions.grounded_entities_document_ids[
# chunk_result.core_entity
# ] = chunk_result.document_id
# mentioned_chunk_entities: set[str] = set()
# for relationship in chunk_result.relationships:
# relationship_split = relationship.split("__")
# if len(relationship_split) == 3:
# if relationship_split[0] not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[relationship_split[0]] += 1
# mentioned_chunk_entities.add(relationship_split[0])
# if relationship_split[2] not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[relationship_split[2]] += 1
# mentioned_chunk_entities.add(relationship_split[2])
# aggregated_kg_extractions.relationships[relationship] += 1
# for kg_entity in chunk_result.entities:
# if kg_entity not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[kg_entity] += 1
# mentioned_chunk_entities.add(kg_entity)
# for kg_term in chunk_result.terms:
# aggregated_kg_extractions.terms[kg_term] += 1
# return KGBatchExtractionStats(
# connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
# succeeded=succeeded_chunk_id,
# failed=failed_chunk_id,
# aggregated_kg_extractions=aggregated_kg_extractions,
# )
def _kg_connector_extraction(
connector_id: str,
tenant_id: str,
) -> None:
logger.info(
f"Starting kg extraction for connector {connector_id} for tenant {tenant_id}"
)
# - grab kg type data from postgres
# - construct prompt
# find all documents for the connector that have not been kg-processed
# - loop for :
# - grab a number of chunks from vespa
# - convert them into the KGChunk format
# - run the extractions in parallel
# - save the results
# - mark chunks as processed
# - update the connector status
#


@@ -0,0 +1,86 @@
from onyx.utils.logger import setup_logger
logger = setup_logger()
"""
def kg_extraction_section(
kg_chunks: list[KGChunk],
llm: LLM,
max_workers: int = 10,
) -> list[KGChunkExtraction]:
def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
metadata_str = "\nMetadata:\n"
for key, value in metadata.items():
value_str = ", ".join(value) if isinstance(value, list) else value
metadata_str += f"{key} - {value_str}\n"
return metadata_str
def _get_usefulness_messages() -> list[dict[str, str]]:
metadata_str = _get_metadata_str(metadata) if metadata else ""
messages = [
{
"role": "user",
"content": SECTION_FILTER_PROMPT.format(
title=title.replace("\n", " "),
chunk_text=section_content,
user_query=query,
optional_metadata=metadata_str,
),
},
]
return messages
def _extract_usefulness(model_output: str) -> bool:
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
return False
return True
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
logger.debug(model_output)
return _extract_usefulness(model_output)
"""
"""
def llm_batch_eval_sections(
query: str,
section_contents: list[str],
llm: LLM,
titles: list[str],
metadata_list: list[dict[str, str | list[str]]],
use_threads: bool = True,
) -> list[bool]:
if DISABLE_LLM_DOC_RELEVANCE:
raise RuntimeError(
"LLM Doc Relevance is globally disabled, "
"this should have been caught upstream."
)
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_section, (query, section_content, llm, title, metadata))
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
logger.debug(
"Running LLM usefulness eval in parallel (following logging may be out of order)"
)
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# In case of failure/timeout, don't throw out the section
return [True if item is None else item for item in parallel_results]
else:
return [
llm_eval_section(query, section_content, llm, title, metadata)
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
"""

backend/onyx/kg/models.py (99 lines added)

@@ -0,0 +1,99 @@
from collections import defaultdict
from typing import Dict
from pydantic import BaseModel
class KGChunkFormat(BaseModel):
connector_id: int | None = None
document_id: str
chunk_id: int
title: str
content: str
primary_owners: list[str]
secondary_owners: list[str]
source_type: str
metadata: dict[str, str | list[str]] | None = None
entities: Dict[str, int] = {}
relationships: Dict[str, int] = {}
terms: Dict[str, int] = {}
class KGChunkExtraction(BaseModel):
connector_id: int
document_id: str
chunk_id: int
core_entity: str
entities: list[str]
relationships: list[str]
terms: list[str]
class KGChunkId(BaseModel):
connector_id: int | None = None
document_id: str
chunk_id: int
class KGAggregatedExtractions(BaseModel):
grounded_entities_document_ids: defaultdict[str, str]
entities: defaultdict[str, int]
relationships: defaultdict[str, int]
terms: defaultdict[str, int]
class KGBatchExtractionStats(BaseModel):
connector_id: int | None = None
succeeded: list[KGChunkId]
failed: list[KGChunkId]
aggregated_kg_extractions: KGAggregatedExtractions
class ConnectorExtractionStats(BaseModel):
connector_id: int
num_succeeded: int
num_failed: int
num_processed: int
class KGPerson(BaseModel):
name: str
company: str
employee: bool
class NormalizedEntities(BaseModel):
entities: list[str]
entity_normalization_map: dict[str, str | None]
class NormalizedRelationships(BaseModel):
relationships: list[str]
relationship_normalization_map: dict[str, str | None]
class NormalizedTerms(BaseModel):
terms: list[str]
term_normalization_map: dict[str, str | None]
class KGClassificationContent(BaseModel):
document_id: str
classification_content: str
source_type: str
class KGClassificationDecisions(BaseModel):
document_id: str
classification_decision: bool
classification_class: str | None
class KGClassificationRule(BaseModel):
description: str
extraction: bool
class KGClassificationInstructionStrings(BaseModel):
classification_options: str
classification_class_definitions: dict[str, Dict[str, str | bool]]


@@ -0,0 +1,55 @@
from typing import Dict
from onyx.kg.context_preparations_extraction.fireflies import (
prepare_llm_content_fireflies,
)
from onyx.kg.context_preparations_extraction.fireflies import (
prepare_llm_document_content_fireflies,
)
from onyx.kg.context_preparations_extraction.models import ContextPreparation
from onyx.kg.context_preparations_extraction.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
def prepare_llm_content(chunk: KGChunkFormat) -> ContextPreparation:
"""
Prepare the content for the LLM.
"""
if chunk.source_type == "fireflies":
return prepare_llm_content_fireflies(chunk)
else:
return ContextPreparation(
llm_context=chunk.content,
core_entity="",
implied_entities=[],
implied_relationships=[],
implied_terms=[],
)
def prepare_llm_document_content(
document_classification_content: KGClassificationContent,
category_list: str,
category_definitions: dict[str, Dict[str, str | bool]],
) -> KGDocumentClassificationPrompt:
"""
Prepare the content for the extraction classification.
"""
category_definition_string = ""
for category, category_data in category_definitions.items():
category_definition_string += f"{category}: {category_data['description']}\n"
if document_classification_content.source_type == "fireflies":
return prepare_llm_document_content_fireflies(
document_classification_content, category_list, category_definition_string
)
else:
return KGDocumentClassificationPrompt(
llm_prompt=None,
)


@@ -0,0 +1,41 @@
from sqlalchemy import select
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_unprocessed_connector_ids(tenant_id: str) -> list[int]:
"""
Retrieves a list of connector IDs that have not been KG processed for a given tenant.
Args:
tenant_id (str): The ID of the tenant to check for unprocessed connectors
Returns:
list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
"""
try:
with get_session_with_current_tenant() as db_session:
# Find connectors that:
# 1. Have KG extraction enabled
# 2. Have documents that haven't been KG processed
stmt = (
select(Connector.id)
.distinct()
.join(DocumentByConnectorCredentialPair)
.where(
Connector.kg_extraction_enabled,
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(False),
)
)
result = db_session.execute(stmt)
return [row[0] for row in result.fetchall()]
except Exception as e:
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
return []


@@ -0,0 +1,23 @@
from typing import List
import numpy as np
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
def encode_string_batch(strings: List[str]) -> np.ndarray:
with get_session_with_current_tenant() as db_session:
current_search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
# Get embeddings while session is still open
embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
return np.array(embedding)
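# Illustrative sketch, not part of the original change: comparing two phrases with the
# embeddings returned above via a plain dot product, mirroring how the normalization
# code scores candidate relationships.
def _example_similarity(a: str, b: str) -> float:
    vectors = encode_string_batch([a, b])
    return float(np.dot(vectors[0], vectors[1]))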


@@ -0,0 +1,123 @@
from collections import defaultdict
from onyx.configs.kg_configs import KG_OWN_COMPANY
from onyx.configs.kg_configs import KG_OWN_EMAIL_DOMAINS
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGPerson
def format_entity(entity: str) -> str:
if len(entity.split(":")) == 2:
entity_type, entity_name = entity.split(":")
return f"{entity_type.upper()}:{entity_name.title()}"
else:
return entity
def format_relationship(relationship: str) -> str:
source_node, relationship_type, target_node = relationship.split("__")
return (
f"{format_entity(source_node)}__"
f"{relationship_type.lower()}__"
f"{format_entity(target_node)}"
)
def format_relationship_type(relationship_type: str) -> str:
source_node_type, relationship_type, target_node_type = relationship_type.split(
"__"
)
return (
f"{source_node_type.upper()}__"
f"{relationship_type.lower()}__"
f"{target_node_type.upper()}"
)
def generate_relationship_type(relationship: str) -> str:
source_node, relationship_type, target_node = relationship.split("__")
return (
f"{source_node.split(':')[0].upper()}__"
f"{relationship_type.lower()}__"
f"{target_node.split(':')[0].upper()}"
)
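# Illustrative sketch, not part of the original change: expected outputs of the
# formatting helpers above for a hypothetical relationship string.
def _example_formatting() -> None:
    rel = "account:acme__works_with__vendor:globex"
    print(format_relationship(rel))  # ACCOUNT:Acme__works_with__VENDOR:Globex
    print(generate_relationship_type(rel))  # ACCOUNT__works_with__VENDOR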
def aggregate_kg_extractions(
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
) -> KGAggregatedExtractions:
aggregated_kg_extractions = KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
for (
grounded_entity,
document_id,
) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
aggregated_kg_extractions.grounded_entities_document_ids[
grounded_entity
] = document_id
for entity, count in connector_aggregated_kg_extractions.entities.items():
aggregated_kg_extractions.entities[entity] += count
for (
relationship,
count,
) in connector_aggregated_kg_extractions.relationships.items():
aggregated_kg_extractions.relationships[relationship] += count
for term, count in connector_aggregated_kg_extractions.terms.items():
aggregated_kg_extractions.terms[term] += count
return aggregated_kg_extractions
def kg_email_processing(email: str) -> KGPerson:
"""
Split an email address into a KGPerson (name, company, and employee flag).
"""
name, company_domain = email.split("@")
assert isinstance(KG_OWN_EMAIL_DOMAINS, list)
employee = any(domain in company_domain for domain in KG_OWN_EMAIL_DOMAINS)
if employee:
company = KG_OWN_COMPANY
else:
company = company_domain.capitalize()
return KGPerson(name=name, company=company, employee=employee)
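# Illustrative sketch, not part of the original change: how an address is mapped to a
# KGPerson, assuming "acme.com" is not among KG_OWN_EMAIL_DOMAINS.
def _example_email_processing() -> None:
    person = kg_email_processing("jane.doe@acme.com")
    print(person.name, person.company, person.employee)  # jane.doe Acme.com False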
def generalize_entities(entities: list[str]) -> set[str]:
"""
Generalize entities to their superclass.
"""
return set([f"{entity.split(':')[0]}:*" for entity in entities])
def generalize_relationships(relationships: list[str]) -> set[str]:
"""
Generalize relationships to their superclass.
"""
generalized_relationships: set[str] = set()
for relationship in relationships:
assert (
len(relationship.split("__")) == 3
), "Relationship is not in the correct format"
source_entity, relationship_type, target_entity = relationship.split("__")
generalized_source_entity = list(generalize_entities([source_entity]))[0]
generalized_target_entity = list(generalize_entities([target_entity]))[0]
generalized_relationships.add(
f"{generalized_source_entity}__{relationship_type}__{target_entity}"
)
generalized_relationships.add(
f"{source_entity}__{relationship_type}__{generalized_target_entity}"
)
generalized_relationships.add(
f"{generalized_source_entity}__{relationship_type}__{generalized_target_entity}"
)
return generalized_relationships
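# Illustrative sketch, not part of the original change: the generalized variants
# produced for a single relationship.
def _example_generalization() -> None:
    rels = generalize_relationships(["ACCOUNT:Acme__works_with__VENDOR:Globex"])
    # Expected members (ordering aside):
    #   ACCOUNT:*__works_with__VENDOR:Globex
    #   ACCOUNT:Acme__works_with__VENDOR:*
    #   ACCOUNT:*__works_with__VENDOR:*
    print(sorted(rels))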


@@ -0,0 +1,179 @@
import json
from collections.abc import Generator
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.kg.context_preparations_extraction.fireflies import (
get_classification_content_from_fireflies_chunks,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
def get_document_classification_content_for_kg_processing(
document_ids: list[str],
index_name: str,
batch_size: int = 8,
num_classification_chunks: int = 3,
) -> Generator[list[KGClassificationContent], None, None]:
"""
Generates the content used for initial classification of a document from
the first num_classification_chunks chunks.
"""
classification_content_list: list[KGClassificationContent] = []
for i in range(0, len(document_ids), batch_size):
batch_document_ids = document_ids[i : i + batch_size]
for document_id in batch_document_ids:
first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(
document_id=document_id,
max_chunk_ind=num_classification_chunks - 1,
min_chunk_ind=0,
),
index_name=index_name,
filters=IndexFilters(access_control_list=None),
field_names=[
"document_id",
"chunk_id",
"title",
"content",
"metadata",
"source_type",
"primary_owners",
"secondary_owners",
],
get_large_chunks=False,
)
first_num_classification_chunks = sorted(
first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
)[:num_classification_chunks]
classification_content = _get_classification_content_from_chunks(
first_num_classification_chunks
)
classification_content_list.append(
KGClassificationContent(
document_id=document_id,
classification_content=classification_content,
source_type=first_num_classification_chunks[0]["fields"][
"source_type"
],
)
)
# Yield the batch of classification content
if classification_content_list:
yield classification_content_list
classification_content_list = []
# Yield any remaining items
if classification_content_list:
yield classification_content_list
def get_document_chunks_for_kg_processing(
document_id: str,
index_name: str,
batch_size: int = 8,
) -> Generator[list[KGChunkFormat], None, None]:
"""
Retrieves chunks from Vespa for the given document ID and converts them to KGChunks.
Args:
document_id (str): ID of the document to fetch chunks for
index_name (str): Name of the Vespa index
batch_size (int): Number of chunks per yielded batch
Yields:
list[KGChunk]: Batches of chunks ready for KG processing
"""
current_batch: list[KGChunkFormat] = []
# get all chunks for the document
chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=IndexFilters(access_control_list=None),
field_names=[
"document_id",
"chunk_id",
"title",
"content",
"metadata",
"primary_owners",
"secondary_owners",
"source_type",
"kg_entities",
"kg_relationships",
"kg_terms",
],
get_large_chunks=False,
)
# Convert Vespa chunks to KGChunks
# kg_chunks: list[KGChunkFormat] = []
for i, chunk in enumerate(chunks):
fields = chunk["fields"]
if isinstance(fields.get("metadata", {}), str):
fields["metadata"] = json.loads(fields["metadata"])
current_batch.append(
KGChunkFormat(
connector_id=None, # We may need to adjust this
document_id=fields.get("document_id"),
chunk_id=fields.get("chunk_id"),
primary_owners=fields.get("primary_owners", []),
secondary_owners=fields.get("secondary_owners", []),
source_type=fields.get("source_type", ""),
title=fields.get("title", ""),
content=fields.get("content", ""),
metadata=fields.get("metadata", {}),
entities=fields.get("kg_entities", {}),
relationships=fields.get("kg_relationships", {}),
terms=fields.get("kg_terms", {}),
)
)
if len(current_batch) >= batch_size:
yield current_batch
current_batch = []
# Yield any remaining chunks
if current_batch:
yield current_batch
def _get_classification_content_from_chunks(
first_num_classification_chunks: list[dict],
) -> str:
"""
Creates a KGClassificationContent object from a list of chunks.
"""
source_type = first_num_classification_chunks[0]["fields"]["source_type"]
if source_type == "fireflies":
classification_content = get_classification_content_from_fireflies_chunks(
first_num_classification_chunks
)
else:
classification_content = (
first_num_classification_chunks[0]["fields"]["title"]
+ "\n"
+ "\n".join(
[
chunk_content["fields"]["content"]
for chunk_content in first_num_classification_chunks
]
)
)
return classification_content

File diff suppressed because it is too large.


@@ -75,12 +75,17 @@ class SearchToolOverrideKwargs(BaseModel):
precomputed_keywords: list[str] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
ordering_only: bool | None = (
None # Flag for fast path when search is only needed for ordering
)
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
expanded_queries: QueryExpansions | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
class Config:
arbitrary_types_allowed = True


@@ -309,6 +309,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
expanded_queries = override_kwargs.expanded_queries
kg_entities = override_kwargs.kg_entities
kg_relationships = override_kwargs.kg_relationships
kg_terms = override_kwargs.kg_terms
# Fast path for ordering-only search
if ordering_only:
@@ -358,6 +361,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Overwrite time-cutoff should supersede the existing time-cutoff, even if defined
retrieval_options.filters.time_cutoff = time_cutoff
# Initialize kg filters in retrieval options and filters with all provided values
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
if kg_entities:
retrieval_options.filters.kg_entities = kg_entities
if kg_relationships:
retrieval_options.filters.kg_relationships = kg_relationships
if kg_terms:
retrieval_options.filters.kg_terms = kg_terms
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,


@@ -80,6 +80,7 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.46.1
supervisor==4.2.5
thefuzz==0.22.1
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0