Compare commits

..

2 Commits

Author    SHA1        Message  Date
pablonyx  bbf5fa13dc  update   2025-04-11 15:13:02 -07:00
pablonyx  a4a399bc31  update   2025-04-11 15:11:35 -07:00
126 changed files with 1077 additions and 10974 deletions

View File

@@ -1,219 +0,0 @@
"""create knowlege graph tables
Revision ID: 495cb26ce93e
Revises: 6a804aeb4830
Create Date: 2025-03-19 08:51:14.341989
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"kg_entity_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column("grounding", sa.String(), nullable=False),
sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column(
"classification_requirements",
postgresql.JSONB,
nullable=False,
server_default="{}",
),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"extraction_sources", postgresql.JSONB, nullable=False, server_default="{}"
),
sa.Column("active", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.Column("grounded_source_name", sa.String(), nullable=True, unique=True),
sa.Column(
"ge_determine_instructions", postgresql.ARRAY(sa.String()), nullable=True
),
sa.Column("ge_grounding_signature", sa.String(), nullable=True),
)
# Create KGRelationshipType table
op.create_table(
"kg_relationship_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("name", sa.String(), nullable=False, index=True),
sa.Column(
"source_entity_type_id_name", sa.String(), nullable=False, index=True
),
sa.Column(
"target_entity_type_id_name", sa.String(), nullable=False, index=True
),
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column("type", sa.String(), nullable=False, index=True),
sa.Column("active", sa.Boolean(), nullable=False, default=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
),
sa.ForeignKeyConstraint(
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
),
)
# Create KGEntity table
op.create_table(
"kg_entity",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("name", sa.String(), nullable=False, index=True),
sa.Column("document_id", sa.String(), nullable=True, index=True),
sa.Column(
"alternative_names",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
sa.Column("description", sa.String(), nullable=True),
sa.Column(
"keywords",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
),
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
)
op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
op.create_index(
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
)
# Create KGRelationship table
op.create_table(
"kg_relationship",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column("source_node", sa.String(), nullable=False, index=True),
sa.Column("target_node", sa.String(), nullable=False, index=True),
sa.Column("type", sa.String(), nullable=False, index=True),
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
sa.Column("cluster_count", sa.Integer(), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
sa.ForeignKeyConstraint(
["relationship_type_id_name"], ["kg_relationship_type.id_name"]
),
sa.UniqueConstraint(
"source_node",
"target_node",
"type",
name="uq_kg_relationship_source_target_type",
),
)
op.create_index(
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
)
# Create KGTerm table
op.create_table(
"kg_term",
sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
sa.Column(
"entity_types",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
),
)
op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
op.create_index("ix_search_term_term", "kg_term", ["id_term"])
op.add_column(
"document",
sa.Column("kg_processed", sa.Boolean(), nullable=False, server_default="false"),
)
op.add_column(
"document",
sa.Column("kg_data", postgresql.JSONB(), nullable=False, server_default="{}"),
)
op.add_column(
"connector",
sa.Column(
"kg_extraction_enabled",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_kg_processed", sa.Boolean(), nullable=True),
)
def downgrade() -> None:
# Drop tables in reverse order of creation to handle dependencies
op.drop_table("kg_term")
op.drop_table("kg_relationship")
op.drop_table("kg_entity")
op.drop_table("kg_relationship_type")
op.drop_table("kg_entity_type")
op.drop_column("connector", "kg_extraction_enabled")
op.drop_column("document_by_connector_credential_pair", "has_been_kg_processed")
op.drop_column("document", "kg_data")
op.drop_column("document", "kg_processed")

View File

@@ -1,57 +0,0 @@
"""Update status length
Revision ID: d961aca62eb3
Revises: cf90764725d8
Create Date: 2025-03-23 16:10:05.683965
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d961aca62eb3"
down_revision = "cf90764725d8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop the existing enum type constraint
op.execute("ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE varchar")
# Create new enum type with all values
op.execute(
"ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE VARCHAR(20) USING status::varchar(20)"
)
# Update the enum type to include all possible values
op.alter_column(
"connector_credential_pair",
"status",
type_=sa.Enum(
"SCHEDULED",
"INITIAL_INDEXING",
"ACTIVE",
"PAUSED",
"DELETING",
"INVALID",
name="connectorcredentialpairstatus",
native_enum=False,
),
existing_type=sa.String(20),
nullable=False,
)
op.add_column(
"connector_credential_pair",
sa.Column(
"in_repeated_error_state", sa.Boolean, default=False, server_default="false"
),
)
def downgrade() -> None:
# no need to convert back to the old enum type, since we're not using it anymore
op.drop_column("connector_credential_pair", "in_repeated_error_state")

View File

@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
(
or_(
ChatMessageFeedback.is_positive.is_(False),
ChatMessageFeedback.required_followup.is_(True),
ChatMessageFeedback.required_followup,
),
1,
),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
.all()
)
return [tuple(row) for row in results]
return results
def fetch_persona_message_analytics(

View File

@@ -406,6 +406,7 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")

View File

@@ -1,5 +1,3 @@
from typing import cast
import numpy as np
import torch
import torch.nn.functional as F
@@ -41,10 +39,10 @@ logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -52,14 +50,13 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
_CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
"distilbert-base-uncased"
)
return _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -95,15 +92,12 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
_INTENT_TOKENIZER = cast(
PreTrainedTokenizer,
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
)
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
return _INTENT_TOKENIZER
@@ -401,9 +395,9 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
tokens = tokenizer.convert_ids_to_tokens(input_ids)
if not len(tokens) == len(is_keyword):
raise ValueError("Length of tokens and keyword predictions must match")

View File

@@ -1,6 +1,5 @@
import json
import os
from typing import cast
import torch
import torch.nn as nn
@@ -14,14 +13,15 @@ class HybridClassifier(nn.Module):
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
# Keyword tokenwise binary classification layer
self.keyword_classifier = nn.Linear(config.dim, 2)
self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)
# Intent Classifier layers
self.pre_classifier = nn.Linear(config.dim, config.dim)
self.intent_classifier = nn.Linear(config.dim, 2)
self.pre_classifier = nn.Linear(
self.distilbert.config.dim, self.distilbert.config.dim
)
self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
self.device = torch.device("cpu")
@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
@@ -79,9 +79,8 @@ class ConnectorClassifier(nn.Module):
self.config = config
self.distilbert = DistilBertModel(config)
config = self.distilbert.config # type: ignore
self.connector_global_classifier = nn.Linear(config.dim, 1)
self.connector_match_classifier = nn.Linear(config.dim, 1)
self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# Token indicating end of connector name, and on which classifier is used
@@ -96,7 +95,7 @@ class ConnectorClassifier(nn.Module):
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert( # type: ignore
hidden_states = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state
@@ -115,10 +114,7 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
)
config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
device = (
torch.device("cuda")
if torch.cuda.is_available()

View File

@@ -1,18 +0,0 @@
from pydantic import BaseModel
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
terms: list[str]
time_filter: str | None
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
terms: list[str]
time_filter: str | None

View File

@@ -37,7 +37,7 @@ def research_object_source(
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
graph_config.inputs.search_request.query
question = graph_config.inputs.search_request.query
object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.search_request.persona is None:

View File

@@ -17,9 +17,6 @@ def research(
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
kg_entities: list[str] | None = None,
kg_relationships: list[str] | None = None,
kg_terms: list[str] | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
@@ -36,9 +33,6 @@ def research(
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
kg_entities=kg_entities,
kg_relationships=kg_relationships,
kg_terms=kg_terms,
),
):
# get retrieved docs to send to the rest of the graph

View File

@@ -1,54 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal
from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
def simple_vs_search(
state: MainState,
) -> Literal["process_kg_only_answers", "construct_deep_search_filters"]:
if state.strategy == KGAnswerStrategy.DEEP or len(state.relationships) > 0:
return "construct_deep_search_filters"
else:
return "process_kg_only_answers"
def research_individual_object(
state: MainState,
# ) -> list[Send | Hashable] | Literal["individual_deep_search"]:
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
# if (
# not state.div_con_entities
# or not state.broken_down_question
# or not state.vespa_filter_results
# ):
# return "individual_deep_search"
# else:
assert state.div_con_entities is not None
assert state.broken_down_question is not None
assert state.vespa_filter_results is not None
return [
Send(
"process_individual_deep_search",
ResearchObjectInput(
entity=entity,
broken_down_question=state.broken_down_question,
vespa_filter_results=state.vespa_filter_results,
source_division=state.source_division,
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for entity in state.div_con_entities
]

View File

@@ -1,138 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b3_consoldidate_individual_deep_search import (
consoldidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def kb_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
### Add nodes ###
graph.add_node(
"extract_ert",
extract_ert,
)
graph.add_node(
"generate_simple_sql",
generate_simple_sql,
)
graph.add_node(
"analyze",
analyze,
)
graph.add_node(
"generate_answer",
generate_answer,
)
graph.add_node(
"construct_deep_search_filters",
construct_deep_search_filters,
)
graph.add_node(
"process_individual_deep_search",
process_individual_deep_search,
)
# graph.add_node(
# "individual_deep_search",
# individual_deep_search,
# )
graph.add_node(
"consoldidate_individual_deep_search",
consoldidate_individual_deep_search,
)
graph.add_node("process_kg_only_answers", process_kg_only_answers)
### Add edges ###
graph.add_edge(start_key=START, end_key="extract_ert")
graph.add_edge(
start_key="extract_ert",
end_key="analyze",
)
graph.add_edge(
start_key="analyze",
end_key="generate_simple_sql",
)
graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
# graph.add_edge(
# start_key="construct_deep_search_filters",
# end_key="process_individual_deep_search",
# )
graph.add_conditional_edges(
source="construct_deep_search_filters",
path=research_individual_object,
path_map=["process_individual_deep_search"],
)
graph.add_edge(
start_key="process_individual_deep_search",
end_key="consoldidate_individual_deep_search",
)
# graph.add_edge(
# start_key="individual_deep_search",
# end_key="consoldidate_individual_deep_search",
# )
graph.add_edge(
start_key="consoldidate_individual_deep_search", end_key="generate_answer"
)
graph.add_edge(
start_key="generate_answer",
end_key=END,
)
return graph
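
A minimal usage sketch for the deleted builder, assuming the standard LangGraph pattern of compiling the StateGraph and passing the GraphConfig through the runnable config metadata; the invocation payload is an assumption, since MainInput's fields are not shown in this diff:

# Hypothetical wiring; only kb_graph_builder() above is taken as given.
compiled_graph = kb_graph_builder().compile()
# Each node reads its GraphConfig via config["metadata"]["config"], so the
# caller would supply it the same way, e.g.:
#   compiled_graph.invoke(main_input, config={"metadata": {"config": graph_config}})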

View File

@@ -1,217 +0,0 @@
from typing import Set
from onyx.agents.agent_search.kb_search.models import KGExpendedGraphObjects
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_relationships_of_entity
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _check_entities_disconnected(
current_entities: list[str], current_relationships: list[str]
) -> bool:
"""
Check if all entities in current_entities are disconnected via the given relationships.
Relationships are in the format: source_entity__relationship_name__target_entity
Args:
current_entities: List of entity IDs to check connectivity for
current_relationships: List of relationships in format source__relationship__target
Returns:
bool: True if all entities are disconnected, False otherwise
"""
if not current_entities:
return True
# Create a graph representation using adjacency list
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# Build the graph from relationships
for relationship in current_relationships:
try:
source, _, target = relationship.split("__")
if source in graph and target in graph:
graph[source].add(target)
graph[target].add(source) # Add reverse edge since graph is undirected
except ValueError:
raise ValueError(f"Invalid relationship format: {relationship}")
# Use BFS to check if all entities are connected
visited: set[str] = set()
start_entity = current_entities[0]
def _bfs(start: str) -> None:
queue = [start]
visited.add(start)
while queue:
current = queue.pop(0)
for neighbor in graph[current]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Start BFS from the first entity
_bfs(start_entity)
logger.debug(f"Number of visited entities: {len(visited)}")
# Check if all current_entities are in visited
return not all(entity in visited for entity in current_entities)
def create_minimal_connected_query_graph(
entities: list[str], relationships: list[str], max_depth: int = 2
) -> KGExpendedGraphObjects:
"""
Find the minimal subgraph that connects all input entities, using only general entities
(<entity_type>:*) as intermediate nodes. The subgraph will include only the relationships
necessary to connect all input entities through the shortest possible paths.
Args:
entities: Initial list of entity IDs
relationships: Initial list of relationships in format source__relationship__target
max_depth: Maximum depth to expand the graph (default: 2)
Returns:
KGExpendedGraphObjects containing expanded entities and relationships
"""
# Create copies of input lists to avoid modifying originals
expanded_entities = entities.copy()
expanded_relationships = relationships.copy()
# Keep track of original entities
original_entities = set(entities)
# Build initial graph from existing relationships
graph: dict[str, set[tuple[str, str]]] = {
entity: set() for entity in expanded_entities
}
for rel in relationships:
try:
source, rel_name, target = rel.split("__")
if source in graph and target in graph:
graph[source].add((target, rel_name))
graph[target].add((source, rel_name))
except ValueError:
continue
# For each depth level
counter = 0
while counter < max_depth:
# Find all connected components in the current graph
components = []
visited = set()
def dfs(node: str, component: set[str]) -> None:
visited.add(node)
component.add(node)
for neighbor, _ in graph.get(node, set()):
if neighbor not in visited:
dfs(neighbor, component)
# Find all components
for entity in expanded_entities:
if entity not in visited:
component: Set[str] = set()
dfs(entity, component)
components.append(component)
# If we only have one component, we're done
if len(components) == 1:
break
# Find the shortest path between any two components using general entities
shortest_path = None
shortest_path_length = float("inf")
for comp1 in components:
for comp2 in components:
if comp1 == comp2:
continue
# Try to find path between entities in different components
for entity1 in comp1:
if not any(e in original_entities for e in comp1):
continue
# entity1_type = entity1.split(":")[0]
with get_session_with_current_tenant() as db_session:
entity1_rels = get_relationships_of_entity(db_session, entity1)
for rel1 in entity1_rels:
try:
source1, rel_name1, target1 = rel1.split("__")
if source1 != entity1:
continue
target1_type = target1.split(":")[0]
general_target = f"{target1_type}:*"
# Try to find path from general_target to comp2
for entity2 in comp2:
if not any(e in original_entities for e in comp2):
continue
with get_session_with_current_tenant() as db_session:
entity2_rels = get_relationships_of_entity(
db_session, entity2
)
for rel2 in entity2_rels:
try:
source2, rel_name2, target2 = rel2.split("__")
if target2 != entity2:
continue
source2_type = source2.split(":")[0]
general_source = f"{source2_type}:*"
if general_target == general_source:
# Found a path of length 2
path = [
(entity1, rel_name1, general_target),
(general_target, rel_name2, entity2),
]
if len(path) < shortest_path_length:
shortest_path = path
shortest_path_length = len(path)
except ValueError:
continue
except ValueError:
continue
# If we found a path, add it to our graph
if shortest_path:
for source, rel_name, target in shortest_path:
# Add general entity if needed
if ":*" in source and source not in expanded_entities:
expanded_entities.append(source)
if ":*" in target and target not in expanded_entities:
expanded_entities.append(target)
# Add relationship
rel = f"{source}__{rel_name}__{target}"
if rel not in expanded_relationships:
expanded_relationships.append(rel)
# Update graph
if source not in graph:
graph[source] = set()
if target not in graph:
graph[target] = set()
graph[source].add((target, rel_name))
graph[target].add((source, rel_name))
counter += 1
logger.debug(f"Number of expanded entities: {len(expanded_entities)}")
logger.debug(f"Number of expanded relationships: {len(expanded_relationships)}")
return KGExpendedGraphObjects(
entities=expanded_entities, relationships=expanded_relationships
)
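
To make the relationship format described in the docstrings above concrete, a minimal sketch with invented entity IDs; it relies only on the deleted _check_entities_disconnected helper shown in this file:

# "source__relationship__target" strings, per the docstring above.
entities = ["account:acme", "employee:jane", "ticket:42"]
relationships = ["employee:jane__works_at__account:acme"]

# ticket:42 is unreachable from account:acme via the undirected edges built
# from the relationships, so the entity set counts as disconnected and the
# query graph would need expansion.
assert _check_entities_disconnected(entities, relationships) is True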

View File

@@ -1,34 +0,0 @@
from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
entities: list[str]
terms: list[str]
time_filter: str | None
class KGAnswerApproach(BaseModel):
strategy: KGAnswerStrategy
format: KGAnswerFormat
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
class KGQuestionRelationshipExtractionResult(BaseModel):
relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
entities: list[str]
relationships: list[str]
terms: list[str]
time_filter: str | None
class KGExpendedGraphObjects(BaseModel):
entities: list[str]
relationships: list[str]
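
For illustration, a minimal construction of KGAnswerApproach; the enum members are the ones referenced elsewhere in this compare (the fallback branch of a2_analyze shown later), not additional assumptions:

# Values mirror the parse-failure fallback in a2_analyze.py.
approach = KGAnswerApproach(
    strategy=KGAnswerStrategy.DEEP,
    format=KGAnswerFormat.TEXT,
    broken_down_question=None,
    divide_and_conquer=YesNoEnum.NO,
)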

View File

@@ -1,240 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import (
ERTExtractionUpdate,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_ert(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
# first four lines are duplicated from generate_initial_answer
question = graph_config.inputs.search_request.query
today_date = datetime.now().strftime("%A, %Y-%m-%d")
all_entity_types = get_entity_types_str(active=None)
all_relationship_types = get_relationship_types_str(active=None)
### get the entities, terms, and filters
query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
entity_types=all_entity_types,
relationship_types=all_relationship_types,
)
query_extraction_prompt = query_extraction_pre_prompt.replace(
"---content---", question
).replace("---today_date---", today_date)
msg = [
HumanMessage(
content=query_extraction_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = (
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
entity_extraction_result = KGQuestionEntityExtractionResult(
entities=[],
terms=[],
time_filter="",
)
ert_entities_string = f"Entities: {entity_extraction_result.entities}\n"
# ert_terms_string = f"Terms: {entity_extraction_result.terms}"
# ert_time_filter_string = f"Time Filter: {entity_extraction_result.time_filter}\n"
### get the relationships
# find the relationship types that match the extracted entity types
with get_session_with_current_tenant() as db_session:
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
db_session, entity_extraction_result.entities
)
query_relationship_extraction_prompt = (
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
.replace("---today_date---", today_date)
.replace(
"---relationship_type_options---",
" - " + "\n - ".join(allowed_relationship_pairs),
)
.replace("---identified_entities---", ert_entities_string)
.replace("---entity_types---", all_entity_types)
)
msg = [
HumanMessage(
content=query_relationship_extraction_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
relationship_extraction_result = (
KGQuestionRelationshipExtractionResult.model_validate_json(
cleaned_response
)
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
relationships=[],
)
# ert_relationships_string = (
# f"Relationships: {relationship_extraction_result.relationships}\n"
# )
##
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_entities_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_relationships_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_terms_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=ert_time_filter_string,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
return ERTExtractionUpdate(
entities_types_str=all_entity_types,
entities=entity_extraction_result.entities,
relationships=relationship_extraction_result.relationships,
terms=entity_extraction_result.terms,
time_filter=entity_extraction_result.time_filter,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)
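
The LLM-response cleanup in extract_ert (strip the markdown fence, trim to the outermost braces, validate with pydantic) is the pattern repeated in several nodes below; a minimal sketch on an invented response string:

# Invented raw LLM output, just to show the cleanup steps used above.
raw = '```json\n{"entities": ["account:*"], "terms": ["churn"], "time_filter": null}\n```'
cleaned = raw.replace("```json\n", "").replace("\n```", "").replace("\n", "")
cleaned = cleaned[cleaned.find("{") : cleaned.rfind("}") + 1]
parsed = KGQuestionEntityExtractionResult.model_validate_json(cleaned)
assert parsed.entities == ["account:*"] and parsed.time_filter is None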

View File

@@ -1,210 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def analyze(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities = state.entities
relationships = state.relationships
terms = state.terms
time_filter = state.time_filter
normalized_entities = normalize_entities(entities)
normalized_relationships = normalize_relationships(
relationships, normalized_entities.entity_normalization_map
)
normalized_terms = normalize_terms(terms)
normalized_time_filter = time_filter
# Expand the entities and relationships to make sure that entities are connected
graph_expansion = create_minimal_connected_query_graph(
normalized_entities.entities,
normalized_relationships.relationships,
max_depth=2,
)
query_graph_entities = graph_expansion.entities
query_graph_relationships = graph_expansion.relationships
# Evaluate whether a search needs to be done after identifying all entities and relationships
strategy_generation_prompt = (
STRATEGY_GENERATION_PROMPT.replace(
"---entities---", "\n".join(query_graph_entities)
)
.replace("---relationships---", "\n".join(query_graph_relationships))
.replace("---terms---", "\n".join(normalized_terms.terms))
.replace("---question---", question)
)
msg = [
HumanMessage(
content=strategy_generation_prompt,
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
20,
# fast_llm.invoke,
primary_llm.invoke,
prompt=msg,
timeout_override=5,
max_tokens=100,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
approach_extraction_result = KGAnswerApproach.model_validate_json(
cleaned_response
)
strategy = approach_extraction_result.strategy
output_format = approach_extraction_result.format
broken_down_question = approach_extraction_result.broken_down_question
divide_and_conquer = approach_extraction_result.divide_and_conquer
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
strategy = KGAnswerStrategy.DEEP
output_format = KGAnswerFormat.TEXT
broken_down_question = None
divide_and_conquer = YesNoEnum.NO
if strategy is None or output_format is None:
raise ValueError(f"Invalid strategy: {cleaned_response}")
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(normalized_entities.entities),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(normalized_relationships.relationships),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(query_graph_entities),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece="\n".join(query_graph_relationships),
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=strategy.value,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=output_format.value,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
return AnalysisUpdate(
normalized_core_entities=normalized_entities.entities,
normalized_core_relationships=normalized_relationships.relationships,
query_graph_entities=query_graph_entities,
query_graph_relationships=query_graph_relationships,
normalized_terms=normalized_terms.terms,
normalized_time_filter=normalized_time_filter,
strategy=strategy,
broken_down_question=broken_down_question,
output_format=output_format,
divide_and_conquer=divide_and_conquer,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="analyze",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,224 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)
def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
"""
Remove aggregate functions from the SQL statement.
"""
sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
"---sql_statement---", sql_statement
)
msg = [
HumanMessage(
content=sql_aggregation_removal_prompt,
)
]
# Grader
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=800,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
sql_statement = cleaned_response.split("SQL:")[1].strip()
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
return sql_statement
def generate_simple_sql(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
state.strategy
state.output_format
simple_sql_prompt = (
SIMPLE_SQL_PROMPT.replace("---entities_types---", entities_types_str)
.replace("---question---", question)
.replace("---query_entities---", "\n".join(state.query_graph_entities))
.replace(
"---query_relationships---", "\n".join(state.query_graph_relationships)
)
)
msg = [
HumanMessage(
content=simple_sql_prompt,
)
]
fast_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
15,
fast_llm.invoke,
prompt=msg,
timeout_override=25,
max_tokens=800,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
sql_statement = cleaned_response.split("SQL:")[1].strip()
sql_statement = sql_statement.split(";")[0].strip() + ";"
sql_statement = sql_statement.replace("sql", "").strip()
# reasoning = cleaned_response.split("SQL:")[0].strip()
except Exception as e:
logger.error(f"Error in strategy generation: {e}")
raise e
if _sql_is_aggregate_query(sql_statement):
individualized_sql_query = _remove_aggregation(sql_statement, llm=fast_llm)
else:
individualized_sql_query = None
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=reasoning,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=cleaned_response,
# level=0,
# level_question_num=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# CRITICAL: EXECUTION OF SQL NEEDS TO BE MADE SAFE FOR PRODUCTION
with get_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(sql_statement))
# Handle scalar results (like COUNT)
if sql_statement.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
query_results = (
[{"count": int(scalar_result) - 1}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
except Exception as e:
logger.error(f"Error executing SQL query: {e}")
raise e
if (
individualized_sql_query is not None
and individualized_sql_query != sql_statement
):
with get_session_with_current_tenant() as db_session:
try:
result = db_session.execute(text(individualized_sql_query))
# Handle scalar results (like COUNT)
if individualized_sql_query.upper().startswith("SELECT COUNT"):
scalar_result = result.scalar()
individualized_query_results = (
[{"count": int(scalar_result) - 1}]
if scalar_result is not None
else []
)
else:
# Handle regular row results
rows = result.fetchall()
individualized_query_results = [dict(row._mapping) for row in rows]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
logger.error(f"Error executing Individualized SQL query: {e}")
individualized_query_results = None
else:
individualized_query_results = None
# write_custom_event(
# "initial_agent_answer",
# AgentAnswerPiece(
# answer_piece=str(query_results),
# level=0,
# answer_type="agent_level_answer",
# ),
# writer,
# )
# dispatch_main_answer_stop_info(0, writer)
logger.info(f"query_results: {query_results}")
return SQLSimpleGenerationUpdate(
sql_query=sql_statement,
query_results=query_results,
individualized_sql_query=individualized_sql_query,
individualized_query_results=individualized_query_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate simple sql",
node_start_time=node_start_time,
)
],
)
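
A quick illustration of the aggregate detection that gates the _remove_aggregation call above; the SQL strings are invented:

# _sql_is_aggregate_query only looks for aggregate function names.
assert _sql_is_aggregate_query("SELECT COUNT(*) FROM kg_entity;") is True
assert _sql_is_aggregate_query("SELECT id_name FROM kg_entity;") is False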

View File

@@ -1,158 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entity_types_with_grounded_source_name
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def construct_deep_search_filters(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
entities = state.query_graph_entities
relationships = state.query_graph_relationships
simple_sql_query = state.sql_query
search_filter_construction_prompt = (
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
"---entity_type_descriptions---",
entities_types_str,
)
.replace(
"---entity_filters---",
"\n".join(entities),
)
.replace(
"---relationship_filters---",
"\n".join(relationships),
)
.replace(
"---sql_query---",
simple_sql_query or "(no SQL generated)",
)
.replace(
"---question---",
question,
)
)
msg = [
HumanMessage(
content=search_filter_construction_prompt,
)
]
llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
vespa_filter_results = KGVespaFilterResults.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
if (
state.individualized_query_results
and len(state.individualized_query_results) > 0
):
div_con_entities = [
x["id_name"]
for x in state.individualized_query_results
if x["id_name"] is not None and "*" not in x["id_name"]
]
elif state.query_results and len(state.query_results) > 0:
div_con_entities = [
x["id_name"]
for x in state.query_results
if x["id_name"] is not None and "*" not in x["id_name"]
]
else:
div_con_entities = []
div_con_entities = list(set(div_con_entities))
logger.info(f"div_con_entities: {div_con_entities}")
with get_session_with_current_tenant() as db_session:
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
db_session
)
source_division = False
if div_con_entities:
for entity_type in double_grounded_entity_types:
if entity_type.grounded_source_name.lower() in div_con_entities[0].lower():
source_division = True
break
else:
raise ValueError("No div_con_entities found")
return DeepSearchFilterUpdate(
vespa_filter_results=vespa_filter_results,
div_con_entities=div_con_entities,
source_division=source_division,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="construct deep search filters",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,133 +0,0 @@
import copy
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def process_individual_deep_search(
state: ResearchObjectInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
question = state.broken_down_question
source_division = state.source_division
if not search_tool:
raise ValueError("search_tool is not provided")
object = state.entity.replace(":", ": ").lower()
if source_division:
extended_question = question
else:
extended_question = f"{question} in regards to {object}"
kg_entity_filters = copy.deepcopy(
state.vespa_filter_results.entity_filters + [state.entity]
)
kg_relationship_filters = copy.deepcopy(
state.vespa_filter_results.relationship_filters
)
logger.info("Research for object: " + object)
logger.info(f"kg_entity_filters: {kg_entity_filters}")
logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
# Add random wait between 1-3 seconds
# time.sleep(random.uniform(0, 3))
retrieved_docs = research(
question=extended_question,
kg_entities=kg_entity_filters,
kg_relationships=kg_relationship_filters,
search_tool=search_tool,
)
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num + 1) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# Built prompt
datetime.now().strftime("%A, %Y-%m-%d")
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
question=extended_question,
document_text=document_texts,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=kg_object_source_research_prompt,
reserved_str="",
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
object_research_results = str(llm_response.content).replace("```json\n", "")
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ResearchObjectUpdate(
research_object_results=[
{
"object": object.replace(":", ": ").capitalize(),
"results": object_research_results,
}
],
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="process individual deep search",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,125 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
entities_types_str = state.entities_types_str
entities = state.query_graph_entities
relationships = state.query_graph_relationships
simple_sql_query = state.sql_query
search_filter_construction_prompt = (
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
"---entity_type_descriptions---",
entities_types_str,
)
.replace(
"---entity_filters---",
"\n".join(entities),
)
.replace(
"---relationship_filters---",
"\n".join(relationships),
)
.replace(
"---sql_query---",
simple_sql_query or "(no SQL generated)",
)
.replace(
"---question---",
question,
)
)
msg = [
HumanMessage(
content=search_filter_construction_prompt,
)
]
llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
15,
llm.invoke,
prompt=msg,
timeout_override=15,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
cleaned_response = cleaned_response.replace("{{", '{"')
cleaned_response = cleaned_response.replace("}}", '"}')
try:
vespa_filter_results = KGVespaFilterResults.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
except Exception as e:
logger.error(f"Error in extract_ert: {e}")
vespa_filter_results = KGVespaFilterResults(
entity_filters=[],
relationship_filters=[],
)
if state.query_results:
div_con_entities = [
x["id_name"] for x in state.query_results if x["id_name"] is not None
]
else:
div_con_entities = []
return DeepSearchFilterUpdate(
vespa_filter_results=vespa_filter_results,
div_con_entities=div_con_entities,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="construct deep search filters",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,45 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def consoldidate_individual_deep_search(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
state.entities_types_str
research_object_results = state.research_object_results
consolidated_research_object_results_str = "\n".join(
[f"{x['object']}: {x['results']}" for x in research_object_results]
)
return ConsolidatedResearchUpdate(
consolidated_research_object_results_str=consolidated_research_object_results_str,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate simple sql",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,107 +0,0 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _general_format(result: dict[str, Any]) -> str:
name = result.get("name")
entity_type_id_name = result.get("entity_type_id_name")
result.get("id_name")
if entity_type_id_name:
return f"{entity_type_id_name.capitalize()}: {name}"
else:
return f"{name}"
def _generate_reference_results(
individualized_query_results: list[dict[str, Any]]
) -> ReferenceResults:
"""
Generate reference results from the query results data string.
"""
citations: list[str] = []
general_entities = []
    # get all entities that correspond to an Onyx document
document_ids: list[str] = [
cast(str, x.get("document_id"))
for x in individualized_query_results
if x.get("document_id")
]
with get_session_with_current_tenant() as session:
llm_doc_information_results = get_base_llm_doc_information(
session, document_ids
)
for llm_doc_information_result in llm_doc_information_results:
citations.append(llm_doc_information_result.center_chunk.semantic_identifier)
for result in individualized_query_results:
document_id: str | None = result.get("document_id")
if not document_id:
general_entities.append(_general_format(result))
return ReferenceResults(citations=citations, general_entities=general_entities)
def process_kg_only_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
"""
    LangGraph node to format the KG query results and reference data for the answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
query_results = state.query_results
individualized_query_results = state.individualized_query_results
query_results_list = []
if query_results:
for query_result in query_results:
query_results_list.append(str(query_result).replace(":", ": ").capitalize())
else:
raise ValueError("No query results were found")
query_results_data_str = "\n".join(query_results_list)
if individualized_query_results:
reference_results = _generate_reference_results(individualized_query_results)
else:
reference_results = None
return ResultsDataUpdate(
query_results_data_str=query_results_data_str,
individualized_query_results_data_str="",
reference_results=reference_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="kg query results data processing",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,157 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
"""
    LangGraph node to generate and stream the final answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
state.entities_types_str
introductory_answer = state.query_results_data_str
state.reference_results
search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
consolidated_research_object_results_str = (
state.consolidated_research_object_results_str
)
question = graph_config.inputs.search_request.query
output_format = state.output_format
if state.reference_results:
examples = (
state.reference_results.citations
or state.reference_results.general_entities
or []
)
research_results = "\n".join([f"- {example}" for example in examples])
elif consolidated_research_object_results_str:
research_results = consolidated_research_object_results_str
else:
research_results = ""
if research_results and introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_PROMPT.replace("---question---", question)
.replace("---introductory_answer---", introductory_answer)
.replace("---output_format---", str(output_format) if output_format else "")
.replace("---research_results---", research_results)
)
elif not research_results and introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace("---introductory_answer---", introductory_answer)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and not introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace("---research_results---", research_results)
)
else:
raise ValueError("No research results or introductory answer provided")
msg = [
HumanMessage(
content=output_format_prompt,
)
]
fast_llm = graph_config.tooling.fast_llm
dispatch_timings: list[float] = []
response: list[str] = []
def stream_answer() -> list[str]:
for message in fast_llm.stream(
prompt=msg,
timeout_override=30,
max_tokens=1000,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
# logger.debug(f"Answer piece: {content}")
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
response = run_with_timeout(
30,
stream_answer,
)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=0,
level_question_num=0,
)
write_custom_event("stream_finished", stop_event, writer)
dispatch_main_answer_stop_info(0, writer)
return MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)

View File

@@ -1,42 +0,0 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
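# Hypothetical usage sketch (names and values assumed, not taken from the call sites):
#   docs = research(
#       question="What changed in the billing pipeline?",
#       search_tool=search_tool,
#       document_sources=[DocumentSource.JIRA],
#       time_cutoff=datetime(2025, 1, 1),
#   )
# Returns at most the top 10 LlmDocs from the final context documents.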

View File

@@ -1,127 +0,0 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class KGVespaFilterResults(BaseModel):
entity_filters: list[str]
relationship_filters: list[str]
class KGAnswerStrategy(Enum):
DEEP = "DEEP"
SIMPLE = "SIMPLE"
class KGAnswerFormat(Enum):
LIST = "LIST"
TEXT = "TEXT"
class YesNoEnum(str, Enum):
YES = "yes"
NO = "no"
class AnalysisUpdate(LoggerUpdate):
normalized_core_entities: list[str] = []
normalized_core_relationships: list[str] = []
query_graph_entities: list[str] = []
query_graph_relationships: list[str] = []
normalized_terms: list[str] = []
normalized_time_filter: str | None = None
strategy: KGAnswerStrategy | None = None
output_format: KGAnswerFormat | None = None
broken_down_question: str | None = None
divide_and_conquer: YesNoEnum | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
sql_query: str | None = None
query_results: list[Dict[Any, Any]] | None = None
individualized_sql_query: str | None = None
individualized_query_results: list[Dict[Any, Any]] | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
vespa_filter_results: KGVespaFilterResults | None = None
div_con_entities: list[str] | None = None
source_division: bool | None = None
class ResearchObjectOutput(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
entities_types_str: str = ""
entities: list[str] = []
relationships: list[str] = []
terms: list[str] = []
time_filter: str | None = None
class ResultsDataUpdate(LoggerUpdate):
query_results_data_str: str | None = None
individualized_query_results_data_str: str | None = None
reference_results: ReferenceResults | None = None
class ResearchObjectUpdate(LoggerUpdate):
research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
ERTExtractionUpdate,
AnalysisUpdate,
SQLSimpleGenerationUpdate,
ResultsDataUpdate,
ResearchObjectOutput,
DeepSearchFilterUpdate,
ResearchObjectUpdate,
ConsolidatedResearchUpdate,
):
pass
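# Note: MainState composes the update models above via multiple inheritance, so each node
# only returns the narrow *Update it owns; Annotated[..., add] fields such as log_messages
# and research_object_results are merged across nodes by LangGraph.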
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]
class ResearchObjectInput(LoggerUpdate):
entity: str
broken_down_question: str
vespa_filter_results: KGVespaFilterResults
source_division: bool | None

View File

@@ -1,8 +1,6 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
@@ -12,21 +10,13 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -40,49 +30,6 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
# TODO: Add trimming logic
history_segments = []
for msg in prompt_builder.message_history:
if isinstance(msg, HumanMessage):
role = "User"
elif isinstance(msg, AIMessage):
role = "Assistant"
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
def _expand_query(
query: str,
expansion_type: QueryExpansionType,
prompt_builder: AnswerPromptBuilder,
) -> str:
history_str = _create_history_str(prompt_builder)
if history_str:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query, history=history_str)
else:
if expansion_type == QueryExpansionType.KEYWORD:
base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
else:
base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
expansion_prompt = base_prompt.format(question=query)
msg = HumanMessage(content=expansion_prompt)
primary_llm, _ = get_default_llms()
response = primary_llm.invoke([msg])
rephrased_query: str = cast(str, response.content)
return rephrased_query
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
@@ -105,16 +52,7 @@ def choose_tool(
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
expanded_keyword_thread: TimeoutThread[str] | None = None
expanded_semantic_thread: TimeoutThread[str] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
@@ -134,20 +72,11 @@ def choose_tool(
agent_config.inputs.search_request.query,
)
if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
expanded_keyword_thread = run_in_background(
_expand_query,
agent_config.inputs.search_request.query,
QueryExpansionType.KEYWORD,
prompt_builder,
)
expanded_semantic_thread = run_in_background(
_expand_query,
agent_config.inputs.search_request.query,
QueryExpansionType.SEMANTIC,
prompt_builder,
)
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
structured_response_format = agent_config.inputs.structured_response_format
tools = [
@@ -280,19 +209,6 @@ def choose_tool(
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
if (
selected_tool.name == SearchTool._NAME
and expanded_keyword_thread
and expanded_semantic_thread
):
keyword_expansion = wait_on_background(expanded_keyword_thread)
semantic_expansion = wait_on_background(expanded_semantic_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.expanded_queries = QueryExpansions(
keywords_expansions=[keyword_expansion],
semantic_expansions=[semantic_expansion],
)
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,

View File

@@ -9,7 +9,6 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.context.search.utils import dedupe_documents
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -51,16 +50,16 @@ def basic_use_tool_response(
final_search_results = []
initial_search_results = []
initial_search_document_ids: set[str] = set()
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
# use same function from _handle_search_tool_response_summary
initial_search_results = [
section_to_llm_doc(section)
for section in dedupe_documents(search_response_summary.top_sections)[0]
]
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_document_ids:
initial_search_document_ids.add(section.center_chunk.document_id)
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -12,16 +12,12 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dc_search_analysis.graph_builder import dc_graph_builder
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
@@ -90,7 +86,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | KBMainInput | DCMainInput,
graph_input: BasicInput | MainInput | DCMainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -104,7 +100,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | KBMainInput | DCMainInput,
input: BasicInput | MainInput | DCMainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -154,15 +150,6 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_kb_graph(
config: GraphConfig,
) -> AnswerStream:
graph = kb_graph_builder()
compiled_graph = graph.compile()
input = KBMainInput(log_messages=[])
return run_graph(compiled_graph, config, input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -27,8 +27,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
logger = setup_logger()
def build_sub_question_answer_prompt(
question: str,

View File

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
@@ -154,14 +153,3 @@ class AnswerGenerationDocuments(BaseModel):
BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
class ReferenceResults(BaseModel):
# citations: list[InferenceSection]
citations: list[str]
general_entities: list[str]

View File

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import is_in_repeated_error_state
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
@@ -55,12 +54,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
@@ -243,16 +241,6 @@ def monitor_ccpair_indexing_taskset(
if not payload:
return
    # if the CC Pair is `SCHEDULED`, move it to `INITIAL_INDEXING`. A CC Pair
# should only ever be `SCHEDULED` if it's a new connector.
cc_pair = get_connector_credential_pair_from_id(db_session, cc_pair_id)
if cc_pair is None:
raise RuntimeError(f"CC Pair {cc_pair_id} not found")
if cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED:
cc_pair.status = ConnectorCredentialPairStatus.INITIAL_INDEXING
db_session.commit()
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
@@ -367,22 +355,6 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
# mark the CC Pair as `ACTIVE` if it's not already
if (
# it should never technically be in this state, but we'll handle it anyway
cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED
or cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
):
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
db_session.commit()
# if the index attempt is successful, clear the repeated error state
if cc_pair.in_repeated_error_state:
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt and index_attempt.status.is_successful():
cc_pair.in_repeated_error_state = False
db_session.commit()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_INDEXING,
@@ -469,21 +441,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# mark CC Pairs that are repeatedly failing as in repeated error state
with get_session_with_current_tenant() as db_session:
current_search_settings = get_current_search_settings(db_session)
for cc_pair_id in cc_pair_ids:
if is_in_repeated_error_state(
cc_pair_id=cc_pair_id,
search_settings_id=current_search_settings.id,
db_session=db_session,
):
set_cc_pair_repeated_error_state(
db_session=db_session,
cc_pair_id=cc_pair_id,
in_repeated_error_state=True,
)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
@@ -523,8 +480,13 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
)
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
@@ -532,6 +494,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
task_logger.info(
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
continue
@@ -539,6 +502,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
task_logger.info(
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)

View File

@@ -22,7 +22,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -32,8 +31,6 @@ from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
@@ -47,8 +44,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
@@ -351,42 +346,9 @@ def validate_indexing_fences(
return
def is_in_repeated_error_state(
cc_pair_id: int, search_settings_id: int, db_session: Session
) -> bool:
"""Checks if the cc pair / search setting combination is in a repeated error state."""
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise RuntimeError(
f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
)
# if the connector doesn't have a refresh_freq, a single failed attempt is enough
number_of_failed_attempts_in_a_row_needed = (
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
if cc_pair.connector.refresh_freq is not None
else 1
)
most_recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
limit=number_of_failed_attempts_in_a_row_needed,
db_session=db_session,
)
return len(
most_recent_index_attempts
) >= number_of_failed_attempts_in_a_row_needed and all(
attempt.status == IndexingStatus.FAILED
for attempt in most_recent_index_attempts
)
def should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
@@ -400,16 +362,6 @@ def should_index(
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
last_index_attempt = get_last_attempt_for_cc_pair(
cc_pair_id=cc_pair.id,
search_settings_id=search_settings_instance.id,
db_session=db_session,
)
all_recent_errored = is_in_repeated_error_state(
cc_pair_id=cc_pair.id,
search_settings_id=search_settings_instance.id,
db_session=db_session,
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
@@ -436,24 +388,24 @@ def should_index(
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index_attempt:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index_attempt.status == IndexingStatus.SUCCESS:
if last_index.status == IndexingStatus.SUCCESS:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
# )
return False
# No new index if the last index attempt is waiting to start
if last_index_attempt.status == IndexingStatus.NOT_STARTED:
if last_index.status == IndexingStatus.NOT_STARTED:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
# )
return False
# No new index if the last index attempt is running
if last_index_attempt.status == IndexingStatus.IN_PROGRESS:
if last_index.status == IndexingStatus.IN_PROGRESS:
# print(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
# )
@@ -487,27 +439,18 @@ def should_index(
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index_attempt:
if not last_index:
return True
if connector.refresh_freq is None:
# print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
return False
# if in the "initial" phase, we should always try and kick-off indexing
# as soon as possible if there is no ongoing attempt. In other words,
# no delay UNLESS we're repeatedly failing to index.
if (
cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
and not all_recent_errored
):
return True
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index_attempt.time_updated
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
# print(
# f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index_attempt.id} "
# f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
# f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
# )
return False

View File

@@ -11,8 +11,6 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -147,21 +145,13 @@ class Answer:
if self.graph_config.behavior.use_agentic_search:
run_langgraph = run_main_graph
elif self.graph_config.inputs.search_request.persona:
if (
elif (
self.graph_config.inputs.search_request.persona
and self.graph_config.inputs.search_request.persona.description.startswith(
"DivCon Beta Agent"
"DivCon Beta Agent"
)
):
run_langgraph = run_dc_graph
elif self.graph_config.inputs.search_request.persona.name.startswith(
"KG Dev"
):
run_langgraph = run_kb_graph
else:
run_langgraph = run_basic_graph
):
run_langgraph = run_dc_graph
else:
run_langgraph = run_basic_graph

View File

@@ -13,7 +13,6 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
@@ -101,39 +100,6 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
)
def inference_section_from_llm_doc(llm_doc: LlmDoc) -> InferenceSection:
# Create a center chunk first
center_chunk = InferenceChunk(
document_id=llm_doc.document_id,
chunk_id=0, # Default to 0 since LlmDoc doesn't have this info
content=llm_doc.content,
blurb=llm_doc.blurb,
semantic_identifier=llm_doc.semantic_identifier,
source_type=llm_doc.source_type,
metadata=llm_doc.metadata,
updated_at=llm_doc.updated_at,
source_links=llm_doc.source_links or {},
match_highlights=llm_doc.match_highlights or [],
section_continuation=False,
image_file_name=None,
title=None,
boost=1,
recency_bias=1.0,
score=None,
hidden=False,
doc_summary="",
chunk_context="",
)
# Create InferenceSection with the center chunk
# Since we don't have access to the original chunks, we'll use an empty list
return InferenceSection(
center_chunk=center_chunk,
chunks=[], # Original surrounding chunks are not available in LlmDoc
combined_content=llm_doc.content,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,

View File

@@ -2,11 +2,9 @@ import time
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import partial
from typing import cast
from typing import Protocol
from uuid import UUID
from sqlalchemy.orm import Session
@@ -84,8 +82,6 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -100,8 +96,6 @@ from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_file_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -165,25 +159,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
COMMON_TOOL_RESPONSE_TYPES = {
"image": ChatFileType.IMAGE,
"csv": ChatFileType.CSV,
}
class PartialResponse(Protocol):
def __call__(
self,
message: str,
rephrased_query: str | None,
reference_docs: list[DbSearchDoc] | None,
files: list[FileDescriptor],
token_count: int,
citations: dict[int, int] | None,
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage: ...
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
@@ -236,25 +211,25 @@ def _handle_search_tool_response_summary(
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None and loaded_user_files is not None:
if user_files is not None:
for user_file in user_files:
if user_file.id in doc_ids:
continue
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
if user_file.id not in doc_ids:
associated_chat_file = None
if loaded_user_files is not None:
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -382,86 +357,6 @@ def _get_force_search_settings(
)
def _get_user_knowledge_files(
info: AnswerPostInfo,
user_files: list[InMemoryChatFile],
file_id_to_user_file: dict[str, InMemoryChatFile],
) -> Generator[UserKnowledgeFilePacket, None, None]:
if not info.qa_docs_response:
return
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
# Add any files that weren't in search results at the end
missing_files = [
f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
]
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
def _get_persona_for_chat_session(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
default_persona: Persona,
) -> Persona:
if new_msg_req.alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
new_msg_req.alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = default_persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
return persona
ChatPacket = (
StreamingError
| QADocsResponse
@@ -483,149 +378,6 @@ ChatPacket = (
ChatPacketStream = Iterator[ChatPacket]
def _process_tool_response(
packet: ToolResponse,
db_session: Session,
selected_db_search_docs: list[DbSearchDoc] | None,
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
retrieval_options: RetrievalDetails | None,
user_file_files: list[UserFile] | None,
user_files: list[InMemoryChatFile] | None,
file_id_to_user_file: dict[str, InMemoryChatFile],
search_for_ordering_only: bool,
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
assert level is not None
assert level_question_num is not None
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
return info_by_subq
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=bool(
not search_for_ordering_only
and retrieval_options
and retrieval_options.dedupe_docs
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=(user_files if search_for_ordering_only else []),
)
# If we're using search just for ordering user files
if search_for_ordering_only and user_files:
yield from _get_user_knowledge_files(
info=info,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
return info_by_subq
if info.reference_db_search_docs is None:
logger.warning("No reference docs found for relevance filtering")
return info_by_subq
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data for img in img_generation_response if img.image_data
],
)
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
response_type = custom_tool_response.response_type
if response_type in COMMON_TOOL_RESPONSE_TYPES:
file_ids = custom_tool_response.tool_result.file_ids
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=file_type)
for file_id in file_ids
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
return info_by_subq
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
@@ -666,20 +418,10 @@ def stream_chat_message_objects(
llm: LLM
index_str = "danswer_chunk_text_embedding_3_small"
if new_msg_req.message == "ee":
kg_extraction(tenant_id, index_str)
raise Exception("Extractions done")
elif new_msg_req.message == "cc":
kg_clustering(tenant_id, index_str)
raise Exception("Clustering done")
try:
# Move these variables inside the try block
file_id_to_user_file = {}
ordered_user_files = None
user_id = user.id if user is not None else None
@@ -694,19 +436,35 @@ def stream_chat_message_objects(
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
new_msg_req.alternate_assistant_id
alternate_assistant_id = new_msg_req.alternate_assistant_id
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
persona = _get_persona_for_chat_session(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
default_persona=chat_session.persona,
)
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
            persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
@@ -988,42 +746,31 @@ def stream_chat_message_objects(
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
def create_response(
message: str,
rephrased_query: str | None,
reference_docs: list[DbSearchDoc] | None,
files: list[FileDescriptor],
token_count: int,
citations: dict[int, int] | None,
error: str | None,
tool_call: ToolCall | None,
) -> ChatMessage:
return create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=(
final_msg
if existing_assistant_message_id is None
else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
message=message,
rephrased_query=rephrased_query,
token_count=token_count,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
error=error,
reference_docs=reference_docs,
files=files,
citations=citations,
tool_call=tool_call,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
partial_response = create_response
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
@@ -1294,23 +1041,220 @@ def stream_chat_message_objects(
use_agentic_search=new_msg_req.use_agentic_search,
)
# reference_db_search_docs = None
# qa_docs_response = None
# # any files to associate with the AI message e.g. dall-e generated images
# ai_message_files = []
# dropped_indices = None
# tool_result = None
# TODO: different channels for stored info when it's coming from the agent flow
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
lambda: AnswerPostInfo(ai_message_files=[])
)
refined_answer_improvement = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
info_by_subq = yield from _process_tool_response(
packet=packet,
db_session=db_session,
selected_db_search_docs=selected_db_search_docs,
info_by_subq=info_by_subq,
retrieval_options=retrieval_options,
user_file_files=user_file_files,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
search_for_ordering_only=search_for_ordering_only,
level, level_question_num = (
(packet.level, packet.level_question_num)
if isinstance(packet, ExtendedToolResponse)
else BASIC_KEY
)
assert level is not None
assert level_question_num is not None
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
continue
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
info.dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=bool(
not search_for_ordering_only
and retrieval_options
and retrieval_options.dedupe_docs
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=(
user_files if search_for_ordering_only else []
),
)
# If we're using search just for ordering user files
if (
search_for_ordering_only
and user_files
and info.qa_docs_response
):
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(
f"ORDERING: Found {len(doc_order)} files from search results"
)
# Add any files that weren't in search results at the end
missing_files = [
f_id
for f_id in file_id_to_user_file.keys()
if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(
f"ORDERING: Added {len(missing_files)} missing files to the end"
)
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id]
for f_id in doc_order
if f_id in file_id_to_user_file
]
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
continue
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
)
continue
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in info.reference_db_search_docs
],
)
if info.dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=info.reference_db_search_docs,
dropped_indices=info.dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data
for img in img_generation_response
if img.image_data
],
)
info.ai_message_files.extend(
[
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield info.qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
if (
custom_tool_response.response_type == "image"
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
info.ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV
),
)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
else:
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1347,46 +1291,22 @@ def stream_chat_message_objects(
if isinstance(e, ToolCallException):
yield StreamingError(error=error_msg, stack_trace=stack_trace)
elif llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
client_error_msg = client_error_msg.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
else:
if llm:
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
stack_trace = stack_trace.replace(
llm.config.api_key, "[REDACTED_API_KEY]"
)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
db_session.rollback()
return
yield from _post_llm_answer_processing(
answer=answer,
info_by_subq=info_by_subq,
tool_dict=tool_dict,
partial_response=partial_response,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
db_session=db_session,
chat_session_id=chat_session_id,
refined_answer_improvement=refined_answer_improvement,
)
def _post_llm_answer_processing(
answer: Answer,
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
tool_dict: dict[int, list[Tool]],
partial_response: PartialResponse,
llm_tokenizer_encode_func: Callable[[str], list[int]],
db_session: Session,
chat_session_id: UUID,
refined_answer_improvement: bool | None,
) -> Generator[ChatPacket, None, None]:
"""
Stores messages in the db and yields some final packets to the frontend
"""
# Post-LLM answer processing
try:
tool_name_to_tool_id: dict[str, int] = {}

View File

@@ -483,6 +483,14 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
DISABLE_INDEX_UPDATE_ON_SWAP = (
os.environ.get("DISABLE_INDEX_UPDATE_ON_SWAP", "").lower() == "true"
)
# Controls how many worker processes we spin up to index documents in the
# background. This is useful for speeding up indexing, but does require a
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
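# Example (illustrative): NUM_INDEXING_WORKERS=4 runs four indexing worker processes, each
# loading its own copy of the embedding models into memory.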
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"

View File

@@ -96,9 +96,3 @@ BING_API_KEY = os.environ.get("BING_API_KEY") or None
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
# Whether or not to use the semantic & keyword search expansions for Basic Search
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
== "true"
)
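# When enabled, choose_tool kicks off background keyword and semantic query-expansion
# threads (via _expand_query) for basic search; defaults to "false".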

View File

@@ -1,12 +0,0 @@
import json
import os
KG_OWN_EMAIL_DOMAINS: list[str] = json.loads(
os.environ.get("KG_OWN_EMAIL_DOMAINS", "[]")
) # must be list
KG_IGNORE_EMAIL_DOMAINS: list[str] = json.loads(
os.environ.get("KG_IGNORE_EMAIL_DOMAINS", "[]")
) # must be list
KG_OWN_COMPANY: str = os.environ.get("KG_OWN_COMPANY", "")
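# Illustrative environment values (hypothetical):
#   KG_OWN_EMAIL_DOMAINS='["example.com"]'
#   KG_IGNORE_EMAIL_DOMAINS='["gmail.com"]'
#   KG_OWN_COMPANY="Example Corp"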

View File

@@ -101,7 +101,6 @@ class ConfluenceConnector(
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._low_timeout_confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
@@ -157,12 +156,6 @@ class ConfluenceConnector(
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
@property
def low_timeout_confluence_client(self) -> OnyxConfluence:
if self._low_timeout_confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._low_timeout_confluence_client
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
@@ -170,27 +163,13 @@ class ConfluenceConnector(
# raises exception if there's a problem
confluence_client = OnyxConfluence(
is_cloud=self.is_cloud,
url=self.wiki_base,
credentials_provider=credentials_provider,
self.is_cloud, self.wiki_base, credentials_provider
)
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
self._confluence_client = confluence_client
# create a low timeout confluence client for sync flows
low_timeout_confluence_client = OnyxConfluence(
is_cloud=self.is_cloud,
url=self.wiki_base,
credentials_provider=credentials_provider,
timeout=3,
)
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
self._low_timeout_confluence_client = low_timeout_confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
@@ -542,8 +521,11 @@ class ConfluenceConnector(
yield doc_metadata_list
def validate_connector_settings(self) -> None:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
try:
spaces = self.low_timeout_confluence_client.get_all_spaces(limit=1)
spaces = self._confluence_client.get_all_spaces(limit=1)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:

View File

@@ -72,14 +72,12 @@ class OnyxConfluence:
CREDENTIAL_PREFIX = "connector:confluence:credential"
CREDENTIAL_TTL = 300 # 5 min
PROBE_TIMEOUT = 5 # 5 seconds
def __init__(
self,
is_cloud: bool,
url: str,
credentials_provider: CredentialsProviderInterface,
timeout: int | None = None,
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
@@ -102,13 +100,11 @@ class OnyxConfluence:
self._kwargs: Any = None
self.shared_base_kwargs: dict[str, str | int | bool] = {
self.shared_base_kwargs = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"cloud": is_cloud,
}
if timeout:
self.shared_base_kwargs["timeout"] = timeout
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
@@ -195,8 +191,6 @@ class OnyxConfluence:
**kwargs: Any,
) -> None:
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
# add special timeout to make sure that we don't hang indefinitely
merged_kwargs["timeout"] = self.PROBE_TIMEOUT
with self._credentials_provider:
credentials, _ = self._renew_credentials()

View File

@@ -18,17 +18,11 @@ from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
)
class QueryExpansions(BaseModel):
keywords_expansions: list[str] | None = None
semantic_expansions: list[str] | None = None
class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
@@ -113,9 +107,6 @@ class BaseFilters(BaseModel):
tags: list[Tag] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
class IndexFilters(BaseFilters):
@@ -148,8 +139,6 @@ class ChunkContext(BaseModel):
class SearchRequest(ChunkContext):
query: str
expanded_queries: QueryExpansions | None = None
search_type: SearchType = SearchType.SEMANTIC
human_selected_filters: BaseFilters | None = None
@@ -198,8 +187,6 @@ class SearchQuery(ChunkContext):
precomputed_query_embedding: Embedding | None = None
expanded_queries: QueryExpansions | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@@ -68,7 +68,6 @@ class SearchPipeline:
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request
self.user = user
self.llm = llm

View File

@@ -20,7 +20,7 @@ from onyx.context.search.models import SearchRequest
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.context.search.utils import (
from onyx.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from onyx.db.models import User
@@ -36,6 +36,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -182,9 +183,6 @@ def retrieval_preprocessing(
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
kg_entities=preset_filters.kg_entities,
kg_relationships=preset_filters.kg_relationships,
kg_terms=preset_filters.kg_terms,
)
llm_evaluation_type = LLMEvaluationType.BASIC
@@ -266,5 +264,4 @@ def retrieval_preprocessing(
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
expanded_queries=search_request.expanded_queries,
)

View File

@@ -2,10 +2,10 @@ import string
from collections.abc import Callable
import nltk # type:ignore
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import ChunkMetric
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
@@ -15,8 +15,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RetrievalMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
@@ -29,9 +27,6 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@@ -41,23 +36,6 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _dedupe_chunks(
chunks: list[InferenceChunkUncleaned],
) -> list[InferenceChunkUncleaned]:
used_chunks: dict[tuple[str, int], InferenceChunkUncleaned] = {}
for chunk in chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in used_chunks:
used_chunks[key] = chunk
else:
stored_chunk_score = used_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
used_chunks[key] = chunk
return list(used_chunks.values())
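A self-contained sketch of the dedupe-by-highest-score idea used above, with a stand-in dataclass instead of InferenceChunkUncleaned (illustrative only):
from dataclasses import dataclass
@dataclass
class _Chunk:  # stand-in for InferenceChunkUncleaned, for illustration only
    document_id: str
    chunk_id: int
    score: float | None
def dedupe(chunks: list[_Chunk]) -> list[_Chunk]:
    # Keep one chunk per (document_id, chunk_id), preferring the higher score.
    best: dict[tuple[str, int], _Chunk] = {}
    for chunk in chunks:
        key = (chunk.document_id, chunk.chunk_id)
        if key not in best or (best[key].score or 0) < (chunk.score or 0):
            best[key] = chunk
    return list(best.values())
deduped = dedupe([_Chunk("doc", 0, 0.4), _Chunk("doc", 0, 0.9)])
assert len(deduped) == 1 and deduped[0].score == 0.9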
def download_nltk_data() -> None:
resources = {
"stopwords": "corpora/stopwords",
@@ -91,6 +69,22 @@ def lemmatize_text(keywords: list[str]) -> list[str]:
# return keywords
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
try:
# Re-tokenize using the NLTK tokenizer for better matching
query = " ".join(keywords)
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(query)
text_trimmed = [
word
for word in word_tokens
if (word.casefold() not in stop_words and word not in string.punctuation)
]
return text_trimmed or word_tokens
except Exception:
return keywords
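For example (assuming the NLTK stopwords and tokenizer data have already been fetched via download_nltk_data), the function keeps only content-bearing tokens and silently falls back to the original keywords on any failure:
# Illustrative usage; exact output depends on the installed NLTK data.
keywords = ["what", "is", "the", "onyx", "connector", "status", "?"]
print(remove_stop_words_and_punctuation(keywords))
# e.g. ['onyx', 'connector', 'status']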
def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
@@ -129,20 +123,6 @@ def get_query_embedding(query: str, db_session: Session) -> Embedding:
return query_embedding
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embeddings = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embeddings
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -159,113 +139,17 @@ def doc_index_retrieval(
query.query, db_session
)
keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
top_base_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = None
top_semantic_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = (
None
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
final_keywords=query.processed_keywords,
filters=query.filters,
hybrid_alpha=query.hybrid_alpha,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
offset=query.offset,
)
keyword_embeddings: list[Embedding] | None = None
semantic_embeddings: list[Embedding] | None = None
top_semantic_chunks: list[InferenceChunkUncleaned] | None = None
# original retrieval method
top_base_chunks_thread = run_in_background(
document_index.hybrid_retrieval,
query.query,
query_embedding,
query.processed_keywords,
query.filters,
query.hybrid_alpha,
query.recency_bias_multiplier,
query.num_hits,
"semantic",
query.offset,
)
if (
query.expanded_queries
and query.expanded_queries.keywords_expansions
and query.expanded_queries.semantic_expansions
):
keyword_embeddings_thread = run_in_background(
get_query_embeddings,
query.expanded_queries.keywords_expansions,
db_session,
)
if query.search_type == SearchType.SEMANTIC:
semantic_embeddings_thread = run_in_background(
get_query_embeddings,
query.expanded_queries.semantic_expansions,
db_session,
)
keyword_embeddings = wait_on_background(keyword_embeddings_thread)
if query.search_type == SearchType.SEMANTIC:
assert semantic_embeddings_thread is not None
semantic_embeddings = wait_on_background(semantic_embeddings_thread)
# Use original query embedding for keyword retrieval embedding
keyword_embeddings = [query_embedding]
# Note: we generally prepped earlier for multiple expansions, but for now we only use one.
top_keyword_chunks_thread = run_in_background(
document_index.hybrid_retrieval,
query.expanded_queries.keywords_expansions[0],
keyword_embeddings[0],
query.processed_keywords,
query.filters,
HYBRID_ALPHA_KEYWORD,
query.recency_bias_multiplier,
query.num_hits,
QueryExpansionType.KEYWORD,
query.offset,
)
if query.search_type == SearchType.SEMANTIC:
assert semantic_embeddings is not None
top_semantic_chunks_thread = run_in_background(
document_index.hybrid_retrieval,
query.expanded_queries.semantic_expansions[0],
semantic_embeddings[0],
query.processed_keywords,
query.filters,
HYBRID_ALPHA,
query.recency_bias_multiplier,
query.num_hits,
QueryExpansionType.SEMANTIC,
query.offset,
)
top_base_chunks = wait_on_background(top_base_chunks_thread)
top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)
if query.search_type == SearchType.SEMANTIC:
assert top_semantic_chunks_thread is not None
top_semantic_chunks = wait_on_background(top_semantic_chunks_thread)
all_top_chunks = top_base_chunks + top_keyword_chunks
# use all three retrieval methods to retrieve top chunks
if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
all_top_chunks += top_semantic_chunks
top_chunks = _dedupe_chunks(all_top_chunks)
else:
top_base_chunks = wait_on_background(top_base_chunks_thread)
top_chunks = _dedupe_chunks(top_base_chunks)
retrieval_requests: list[VespaChunkRequest] = []
normal_chunks: list[InferenceChunkUncleaned] = []
referenced_chunk_scores: dict[tuple[str, int], float] = {}
@@ -349,8 +233,6 @@ def retrieve_chunks(
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
logger.info(f"RETRIEVAL CHUNKS query: {query}")
multilingual_expansion = get_multilingual_expansion(db_session)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:

View File

@@ -1,10 +1,6 @@
import string
from collections.abc import Sequence
from typing import TypeVar
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
@@ -140,19 +136,3 @@ def chunks_or_sections_to_search_docs(
]
return search_docs
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
try:
# Re-tokenize using the NLTK tokenizer for better matching
query = " ".join(keywords)
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(query)
text_trimmed = [
word
for word in word_tokens
if (word.casefold() not in stop_words and word not in string.punctuation)
]
return text_trimmed or word_tokens
except Exception:
return keywords

View File

@@ -56,8 +56,7 @@ def get_total_users_count(db_session: Session) -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
count_stmt = func.count(User.id) # type: ignore
stmt = select(count_stmt)
stmt = select(func.count(User.id))
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)

View File

@@ -334,21 +334,3 @@ def mark_ccpair_with_indexing_trigger(
except Exception:
db_session.rollback()
raise
def get_unprocessed_connector_ids(db_session: Session) -> list[int]:
"""
Retrieves the IDs of connectors that have knowledge graph extraction enabled (candidates for KG processing).
Args:
db_session (Session): The database session to use
Returns:
list[int]: List of connector IDs that have KG extraction enabled
"""
try:
stmt = select(Connector.id).where(Connector.kg_extraction_enabled)
result = db_session.execute(stmt)
return [row[0] for row in result.fetchall()]
except Exception as e:
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
raise e

View File

@@ -7,7 +7,6 @@ from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
@@ -395,20 +394,6 @@ def update_connector_credential_pair(
)
def set_cc_pair_repeated_error_state(
db_session: Session,
cc_pair_id: int,
in_repeated_error_state: bool,
) -> None:
stmt = (
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.values(in_repeated_error_state=in_repeated_error_state)
)
db_session.execute(stmt)
db_session.commit()
def delete_connector_credential_pair__no_commit(
db_session: Session,
connector_id: int,
@@ -472,7 +457,7 @@ def add_credential_to_connector(
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
last_successful_index_time: datetime | None = None,
seeding_flow: bool = False,
is_user_file: bool = False,

View File

@@ -23,8 +23,6 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
@@ -379,7 +377,6 @@ def upsert_documents(
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
kg_processed=False,
)
)
for doc in seen_documents.values()
@@ -846,213 +843,3 @@ def fetch_chunk_count_for_document(
) -> int | None:
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()
def get_unprocessed_kg_documents_for_connector(
db_session: Session,
connector_id: int,
batch_size: int = 100,
) -> Generator[DbDocument, None, None]:
"""
Retrieves all documents associated with a connector that have not yet been processed
for knowledge graph extraction. Uses a generator pattern to handle large result sets.
Args:
db_session (Session): The database session to use
connector_id (int): The ID of the connector to check
batch_size (int): Number of documents to fetch per batch, defaults to 100
Yields:
DbDocument: Documents that haven't been KG processed, one at a time
"""
offset = 0
while True:
stmt = (
select(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
or_(
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
None
),
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
False
),
),
)
)
.distinct()
.limit(batch_size)
.offset(offset)
)
batch = list(db_session.scalars(stmt).all())
if not batch:
break
for document in batch:
yield document
offset += batch_size
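A brief usage sketch of the generator above; the connector ID is hypothetical and the session factory is assumed to be the get_session_context_manager imported elsewhere in this change.
# Hypothetical usage: stream unprocessed documents for connector 42, 100 per batch.
with get_session_context_manager() as db_session:
    for document in get_unprocessed_kg_documents_for_connector(db_session, connector_id=42):
        print(document.id)
Note that the offset-based pagination assumes the unprocessed set does not shrink while iterating; documents marked as processed mid-iteration would shift later pages.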
def get_kg_processed_document_ids(db_session: Session) -> list[str]:
"""
Retrieves all document IDs where kg_processed is True.
Args:
db_session (Session): The database session to use
Returns:
list[str]: List of document IDs that have been KG processed
"""
stmt = select(DbDocument.id).where(DbDocument.kg_processed.is_(True))
return list(db_session.scalars(stmt).all())
def update_document_kg_info(
db_session: Session,
document_id: str,
kg_processed: bool,
kg_data: dict,
) -> None:
"""Updates the knowledge graph related information for a document.
Args:
db_session (Session): The database session to use
document_id (str): The ID of the document to update
kg_processed (bool): Whether the document has been processed for KG extraction
kg_data (dict): Dictionary containing KG data with 'entities', 'relationships', and 'terms' keys
Note:
If no document with the given ID exists, the update affects no rows and no error is raised.
"""
stmt = (
update(DbDocument)
.where(DbDocument.id == document_id)
.values(
kg_processed=kg_processed,
kg_data=kg_data,
)
)
db_session.execute(stmt)
def get_document_kg_info(
db_session: Session,
document_id: str,
) -> tuple[bool, dict] | None:
"""Retrieves the knowledge graph processing status and data for a document.
Args:
db_session (Session): The database session to use
document_id (str): The ID of the document to query
Returns:
Optional[Tuple[bool, dict]]: A tuple containing:
- bool: Whether the document has been KG processed
- dict: The KG data containing 'entities', 'relationships', and 'terms'
Returns None if the document is not found
"""
stmt = select(DbDocument.kg_processed, DbDocument.kg_data).where(
DbDocument.id == document_id
)
result = db_session.execute(stmt).one_or_none()
if result is None:
return None
return result.kg_processed, result.kg_data or {}
def get_all_kg_processed_documents_info(
db_session: Session,
) -> list[tuple[str, dict]]:
"""Retrieves the knowledge graph data for all documents that have been processed.
Args:
db_session (Session): The database session to use
Returns:
List[Tuple[str, dict]]: A list of tuples containing:
- str: The document ID
- dict: The KG data containing 'entities', 'relationships', and 'terms'
Only returns documents where kg_processed is True
"""
stmt = (
select(DbDocument.id, DbDocument.kg_data)
.where(DbDocument.kg_processed.is_(True))
.order_by(DbDocument.id)
)
results = db_session.execute(stmt).all()
return [(str(doc_id), kg_data or {}) for doc_id, kg_data in results]
def get_base_llm_doc_information(
db_session: Session, document_ids: list[str]
) -> list[InferenceSection]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
results = db_session.execute(stmt).all()
inference_sections = []
for doc in results:
bare_doc = doc[0]
inference_section = InferenceSection(
center_chunk=InferenceChunk(
document_id=bare_doc.id,
chunk_id=0,
source_type=DocumentSource.NOT_APPLICABLE,
semantic_identifier=bare_doc.semantic_id,
title=None,
boost=0,
recency_bias=0,
score=0,
hidden=False,
metadata={},
blurb="",
content="",
source_links=None,
image_file_name=None,
section_continuation=False,
match_highlights=[],
doc_summary="",
chunk_context="",
updated_at=None,
),
chunks=[],
combined_content="",
)
inference_sections.append(inference_section)
return inference_sections
def get_document_updated_at(
document_id: str,
db_session: Session,
) -> datetime | None:
"""Retrieves the doc_updated_at timestamp for a given document ID.
Args:
document_id (str): The ID of the document to query
db_session (Session): The database session to use
Returns:
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
"""
if len(document_id.split(":")) == 2:
document_id = document_id.split(":")[1]
elif len(document_id.split(":")) > 2:
raise ValueError(f"Invalid document ID: {document_id}")
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()

View File

@@ -1,271 +0,0 @@
from datetime import datetime
from typing import List
from sqlalchemy import literal_column
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityType
def get_entity_types(
db_session: Session,
active: bool | None = True,
) -> list[KGEntityType]:
# Query the database for all distinct entity types
if active is None:
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
else:
return (
db_session.query(KGEntityType)
.filter(KGEntityType.active == active)
.order_by(KGEntityType.id_name)
.all()
)
def add_entity(
db_session: Session,
entity_type: str,
name: str,
document_id: str | None = None,
cluster_count: int = 0,
event_time: datetime | None = None,
) -> "KGEntity | None":
"""Add a new entity to the database.
Args:
db_session: SQLAlchemy session
entity_type: Type of the entity (must match an existing KGEntityType)
name: Name of the entity
cluster_count: Number of clusters in which this entity has been found
Returns:
KGEntity: The created entity
"""
entity_type = entity_type.upper()
name = name.title()
id_name = f"{entity_type}:{name}"
# Create new entity
stmt = (
pg_insert(KGEntity)
.values(
id_name=id_name,
entity_type_id_name=entity_type,
document_id=document_id,
name=name,
cluster_count=cluster_count,
event_time=event_time,
)
.on_conflict_do_update(
index_elements=["id_name"],
set_=dict(
# Direct numeric addition without text()
cluster_count=KGEntity.cluster_count
+ literal_column("EXCLUDED.cluster_count"),
# Keep other fields updated as before
entity_type_id_name=entity_type,
document_id=document_id,
name=name,
event_time=event_time,
),
)
.returning(KGEntity)
)
result = db_session.execute(stmt).scalar()
return result
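A hedged usage sketch (the session and values are illustrative); on conflict the upsert adds the incoming cluster_count to the existing row's count:
# Illustrative only: calling add_entity twice for the same entity accumulates cluster_count.
entity = add_entity(db_session, entity_type="account", name="acme corp", cluster_count=1)
entity = add_entity(db_session, entity_type="account", name="acme corp", cluster_count=2)
# id_name becomes "ACCOUNT:Acme Corp"; cluster_count is now 3 via the
# ON CONFLICT ... DO UPDATE arithmetic above.
db_session.commit()  # assumes the caller owns transaction boundaries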
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
"""
Check if a document_id exists in the kg_entity table and return the matching entity if found.
Args:
db: SQLAlchemy database session
document_id: The document ID to search for
Returns:
The matching KGEntity if found, None otherwise
"""
query = select(KGEntity).where(KGEntity.document_id == document_id)
result = db.execute(query).scalar()
return result
def get_ungrounded_entities(db_session: Session) -> List[KGEntity]:
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntity objects belonging to ungrounded entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntityType.grounding == "UE")
.all()
)
def get_grounded_entities(db_session: Session) -> List[KGEntity]:
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntity objects belonging to grounded entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntityType.grounding == "GE")
.all()
)
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
"""Get all entity types that have non-null ge_determine_instructions.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have ge_determine_instructions defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.ge_determine_instructions.isnot(None))
.all()
)
def get_entity_types_with_grounded_source_name(
db_session: Session,
) -> List[KGEntityType]:
"""Get all entity types that have non-null grounded_source_name.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have grounded_source_name defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.grounded_source_name.isnot(None))
.all()
)
def get_entity_types_with_grounding_signature(
db_session: Session,
) -> List[KGEntityType]:
"""Get all entity types that have non-null ge_grounding_signature.
Args:
db_session: SQLAlchemy session
Returns:
List of KGEntityType objects that have ge_grounding_signature defined
"""
return (
db_session.query(KGEntityType)
.filter(KGEntityType.ge_grounding_signature.isnot(None))
.all()
)
def get_ge_entities_by_types(
db_session: Session, entity_types: List[str]
) -> List[KGEntity]:
"""Get all entities matching an entity_type.
Args:
db_session: SQLAlchemy session
entity_types: List of entity types to filter by
Returns:
List of KGEntity objects belonging to the specified entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntity.entity_type_id_name.in_(entity_types))
.filter(KGEntityType.grounding == "GE")
.all()
)
def delete_entities_by_id_names(db_session: Session, id_names: list[str]) -> int:
"""
Delete entities from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of entity id_names to delete
Returns:
Number of entities deleted
"""
deleted_count = (
db_session.query(KGEntity)
.filter(KGEntity.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def get_entities_for_types(
db_session: Session, entity_types: List[str]
) -> List[KGEntity]:
"""Get all entities that belong to the specified entity types.
Args:
db_session: SQLAlchemy session
entity_types: List of entity type id_names to filter by
Returns:
List of KGEntity objects belonging to the specified entity types
"""
return (
db_session.query(KGEntity)
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
.filter(KGEntity.entity_type_id_name.in_(entity_types))
.all()
)
def get_entity_type_by_grounded_source_name(
db_session: Session, grounded_source_name: str
) -> KGEntityType | None:
"""Get an entity type by its grounded_source_name and return it as a dictionary.
Args:
db_session: SQLAlchemy session
grounded_source_name: The grounded_source_name of the entity type to retrieve
Returns:
The matching KGEntityType object, or None if no entity type is found
"""
entity_type = (
db_session.query(KGEntityType)
.filter(KGEntityType.grounded_source_name == grounded_source_name)
.first()
)
if entity_type is None:
return None
return entity_type

View File

@@ -18,12 +18,6 @@ class IndexingStatus(str, PyEnum):
}
return self in terminal_states
def is_successful(self) -> bool:
return (
self == IndexingStatus.SUCCESS
or self == IndexingStatus.COMPLETED_WITH_ERRORS
)
class IndexingMode(str, PyEnum):
UPDATE = "update"
@@ -79,19 +73,13 @@ class ChatSessionSharedStatus(str, PyEnum):
class ConnectorCredentialPairStatus(str, PyEnum):
SCHEDULED = "SCHEDULED"
INITIAL_INDEXING = "INITIAL_INDEXING"
ACTIVE = "ACTIVE"
PAUSED = "PAUSED"
DELETING = "DELETING"
INVALID = "INVALID"
def is_active(self) -> bool:
return (
self == ConnectorCredentialPairStatus.ACTIVE
or self == ConnectorCredentialPairStatus.SCHEDULED
or self == ConnectorCredentialPairStatus.INITIAL_INDEXING
)
return self == ConnectorCredentialPairStatus.ACTIVE
class AccessType(str, PyEnum):

View File

@@ -59,7 +59,6 @@ def get_recent_completed_attempts_for_cc_pair(
limit: int,
db_session: Session,
) -> list[IndexAttempt]:
"""Most recent to least recent."""
return (
db_session.query(IndexAttempt)
.filter(
@@ -75,25 +74,6 @@ def get_recent_completed_attempts_for_cc_pair(
)
def get_recent_attempts_for_cc_pair(
cc_pair_id: int,
search_settings_id: int,
limit: int,
db_session: Session,
) -> list[IndexAttempt]:
"""Most recent to least recent."""
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.search_settings_id == search_settings_id,
)
.order_by(IndexAttempt.time_updated.desc())
.limit(limit)
.all()
)
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:

View File

@@ -53,7 +53,6 @@ from onyx.db.enums import (
SyncType,
SyncStatus,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -438,9 +437,6 @@ class ConnectorCredentialPair(Base):
status: Mapped[ConnectorCredentialPairStatus] = mapped_column(
Enum(ConnectorCredentialPairStatus, native_enum=False), nullable=False
)
# this is separate from the `status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`,
# or `PAUSED` and still be in a repeated error state.
in_repeated_error_state: Mapped[bool] = mapped_column(Boolean, default=False)
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
@@ -587,22 +583,6 @@ class Document(Base):
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
# tables for the knowledge graph data
kg_processed: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this document has been processed for knowledge graph extraction",
)
kg_data: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Knowledge graph data extracted from this document",
)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
@@ -621,304 +601,6 @@ class Document(Base):
)
class KGEntityType(Base):
__tablename__ = "kg_entity_type"
# Primary identifier
id_name: Mapped[str] = mapped_column(
String, primary_key=True, nullable=False, index=True
)
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
grounding: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
)
grounded_source_name: Mapped[str] = mapped_column(
NullFilteredString, nullable=False, index=False
)
ge_determine_instructions: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=True, default=None
)
ge_grounding_signature: Mapped[str] = mapped_column(
NullFilteredString, nullable=True, index=False, default=None
)
clustering: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Clustering information for this entity type",
)
classification_requirements: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Pre-extraction classification requirements and instructions",
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
extraction_sources: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
comment="Sources and methods used to extract this entity",
)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
class KGRelationshipType(Base):
__tablename__ = "kg_relationship_type"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
nullable=False,
index=True,
)
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
source_entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
target_entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
definition: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this relationship type represents a definition",
)
clustering: Mapped[dict] = mapped_column(
postgresql.JSONB,
nullable=False,
default=dict,
server_default="{}",
comment="Clustering information for this relationship type",
)
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# Relationships to EntityType
source_type: Mapped["KGEntityType"] = relationship(
"KGEntityType",
foreign_keys=[source_entity_type_id_name],
backref="source_relationship_type",
)
target_type: Mapped["KGEntityType"] = relationship(
"KGEntityType",
foreign_keys=[target_entity_type_id_name],
backref="target_relationship_type",
)
class KGEntity(Base):
__tablename__ = "kg_entity"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, index=True
)
# Basic entity information
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
document_id: Mapped[str | None] = mapped_column(
NullFilteredString, nullable=True, index=True
)
alternative_names: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Reference to KGEntityType
entity_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_entity_type.id_name"),
nullable=False,
index=True,
)
# Relationship to KGEntityType
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
description: Mapped[str | None] = mapped_column(String, nullable=True)
keywords: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Access control
acl: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Boosts - using JSON for flexibility
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
event_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="Time of the event being processed",
)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
__table_args__ = (
# Fixed column names in indexes
Index("ix_entity_type_acl", entity_type_id_name, acl),
Index("ix_entity_name_search", name, entity_type_id_name),
)
class KGRelationship(Base):
__tablename__ = "kg_relationship"
# Primary identifier
id_name: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, index=True
)
# Source and target nodes (foreign keys to Entity table)
source_node: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
)
target_node: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
)
# Relationship type
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
# Add new relationship type reference
relationship_type_id_name: Mapped[str] = mapped_column(
NullFilteredString,
ForeignKey("kg_relationship_type.id_name"),
nullable=False,
index=True,
)
# Add the SQLAlchemy relationship property
relationship_type: Mapped["KGRelationshipType"] = relationship(
"KGRelationshipType", backref="relationship"
)
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# Relationships to Entity table
source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
__table_args__ = (
# Index for querying relationships by type
Index("ix_kg_relationship_type", type),
# Composite index for source/target queries
Index("ix_kg_relationship_nodes", source_node, target_node),
# Ensure unique relationships between nodes of a specific type
UniqueConstraint(
"source_node",
"target_node",
"type",
name="uq_kg_relationship_source_target_type",
),
)
class KGTerm(Base):
__tablename__ = "kg_term"
# Make id_term the primary key
id_term: Mapped[str] = mapped_column(
NullFilteredString, primary_key=True, nullable=False, index=True
)
# List of entity types this term applies to
entity_types: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), nullable=False, default=list
)
# Tracking fields
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
__table_args__ = (
# Index for searching terms with specific entity types
Index("ix_search_term_entities", entity_types),
# Index for term lookups
Index("ix_search_term_term", id_term),
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
@@ -1006,14 +688,6 @@ class Connector(Base):
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
kg_extraction_enabled: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this connector should extract knowledge graph entities",
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -1457,12 +1131,6 @@ class DocumentByConnectorCredentialPair(Base):
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
has_been_kg_processed: Mapped[bool | None] = mapped_column(
Boolean,
nullable=True,
comment="Whether this document has been processed for knowledge graph extraction",
)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector", passive_deletes=True
)

View File

@@ -1,298 +0,0 @@
from typing import List
from sqlalchemy import or_
from sqlalchemy.orm import Session
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipType
from onyx.kg.utils.formatting_utils import format_entity
from onyx.kg.utils.formatting_utils import format_relationship
from onyx.kg.utils.formatting_utils import generate_relationship_type
def add_relationship(
db_session: Session,
relationship_id_name: str,
cluster_count: int | None = None,
) -> "KGRelationship":
"""
Add a relationship between two entities to the database.
Args:
db_session: SQLAlchemy database session
relationship_id_name: The relationship ID name in the format "source__relationship__target"
cluster_count: Optional count of similar relationships clustered together
Returns:
The created KGRelationship object
Raises:
sqlalchemy.exc.IntegrityError: If the relationship already exists or entities don't exist
"""
# Generate a unique ID for the relationship
(
source_entity_id_name,
relationship_string,
target_entity_id_name,
) = relationship_id_name.split("__")
source_entity_id_name = format_entity(source_entity_id_name)
target_entity_id_name = format_entity(target_entity_id_name)
relationship_id_name = format_relationship(relationship_id_name)
relationship_type = generate_relationship_type(relationship_id_name)
# Create new relationship
relationship = KGRelationship(
id_name=relationship_id_name,
source_node=source_entity_id_name,
target_node=target_entity_id_name,
type=relationship_string.lower(),
relationship_type_id_name=relationship_type,
cluster_count=cluster_count,
)
db_session.add(relationship)
db_session.flush() # Flush to get any DB errors early
return relationship
def add_or_increment_relationship(
db_session: Session,
relationship_id_name: str,
) -> "KGRelationship":
"""
Add a relationship between two entities to the database if it doesn't exist,
or increment its cluster_count by 1 if it already exists.
Args:
db_session: SQLAlchemy database session
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
Returns:
The created or updated KGRelationship object
Raises:
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
"""
# Format the relationship_id_name
relationship_id_name = format_relationship(relationship_id_name)
# Check if the relationship already exists
existing_relationship = (
db_session.query(KGRelationship)
.filter(KGRelationship.id_name == relationship_id_name)
.first()
)
if existing_relationship:
# If it exists, increment the cluster_count
existing_relationship.cluster_count = (
existing_relationship.cluster_count or 0
) + 1
db_session.flush()
return existing_relationship
else:
# If it doesn't exist, add it with cluster_count=1
return add_relationship(db_session, relationship_id_name, cluster_count=1)
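For example (id_name in the "source__relationship__target" format parsed by add_relationship above; the session is assumed to be supplied by the caller):
# Illustrative only: first call inserts with cluster_count=1, second call increments to 2.
rel = add_or_increment_relationship(db_session, "ACCOUNT:Acme Corp__uses__PRODUCT:Onyx")
rel = add_or_increment_relationship(db_session, "ACCOUNT:Acme Corp__uses__PRODUCT:Onyx")
assert rel.cluster_count == 2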
def add_relationship_type(
db_session: Session,
source_entity_type: str,
relationship_type: str,
target_entity_type: str,
definition: bool = False,
extraction_count: int = 0,
) -> "KGRelationshipType":
"""
Add a new relationship type to the database.
Args:
db_session: SQLAlchemy session
source_entity_type: Type of the source entity
relationship_type: Type of relationship
target_entity_type: Type of the target entity
definition: Whether this relationship type represents a definition (default False)
Returns:
The created KGRelationshipType object
Raises:
sqlalchemy.exc.IntegrityError: If the relationship type already exists
"""
id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"
# Create new relationship type
rel_type = KGRelationshipType(
id_name=id_name,
name=relationship_type,
source_entity_type_id_name=source_entity_type.upper(),
target_entity_type_id_name=target_entity_type.upper(),
definition=definition,
cluster_count=extraction_count,
type=relationship_type, # Using the relationship_type as the type
active=True, # Setting as active by default
)
db_session.add(rel_type)
db_session.flush() # Flush to get any DB errors early
return rel_type
def get_all_relationship_types(db_session: Session) -> list["KGRelationshipType"]:
"""
Retrieve all relationship types from the database.
Args:
db_session: SQLAlchemy database session
Returns:
List of KGRelationshipType objects
"""
return db_session.query(KGRelationshipType).all()
def get_all_relationships(db_session: Session) -> list["KGRelationship"]:
"""
Retrieve all relationships from the database.
Args:
db_session: SQLAlchemy database session
Returns:
List of KGRelationship objects
"""
return db_session.query(KGRelationship).all()
def delete_relationships_by_id_names(db_session: Session, id_names: list[str]) -> int:
"""
Delete relationships from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of relationship id_names to delete
Returns:
Number of relationships deleted
Raises:
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
"""
deleted_count = (
db_session.query(KGRelationship)
.filter(KGRelationship.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def delete_relationship_types_by_id_names(
db_session: Session, id_names: list[str]
) -> int:
"""
Delete relationship types from the database based on a list of id_names.
Args:
db_session: SQLAlchemy database session
id_names: List of relationship type id_names to delete
Returns:
Number of relationship types deleted
Raises:
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
"""
deleted_count = (
db_session.query(KGRelationshipType)
.filter(KGRelationshipType.id_name.in_(id_names))
.delete(synchronize_session=False)
)
db_session.flush() # Flush to ensure deletion is processed
return deleted_count
def get_relationships_for_entity_type_pairs(
db_session: Session, entity_type_pairs: list[tuple[str, str]]
) -> list["KGRelationshipType"]:
"""
Get relationship types from the database based on a list of entity type pairs.
Args:
db_session: SQLAlchemy database session
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
Returns:
List of KGRelationshipType objects where source and target types match the provided pairs
"""
conditions = [
(
(KGRelationshipType.source_entity_type_id_name == source_type)
& (KGRelationshipType.target_entity_type_id_name == target_type)
)
for source_type, target_type in entity_type_pairs
]
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
def get_allowed_relationship_type_pairs(
db_session: Session, entities: list[str]
) -> list[str]:
"""
Get the allowed relationship pairs for the given entities.
Args:
db_session: SQLAlchemy database session
entities: List of entity type ID names to filter by
Returns:
List of id_names from KGRelationshipType where both source and target entity types
are in the provided entities list
"""
entity_types = list(set([entity.split(":")[0] for entity in entities]))
return [
row[0]
for row in (
db_session.query(KGRelationshipType.id_name)
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
.distinct()
.all()
)
]
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
"""Get all relationship ID names where the given entity is either the source or target node.
Args:
db_session: SQLAlchemy session
entity_id: ID of the entity to find relationships for
Returns:
List of relationship ID names where the entity is either source or target
"""
return [
row[0]
for row in (
db_session.query(KGRelationship.id_name)
.filter(
or_(
KGRelationship.source_node == entity_id,
KGRelationship.target_node == entity_id,
)
)
.all()
)
]

View File

@@ -4,8 +4,6 @@ from datetime import datetime
from typing import Any
from onyx.access.models import DocumentAccess
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
@@ -98,7 +96,7 @@ class VespaDocumentFields:
understandable like this for now.
"""
# all other fields except these 4 and knowledge graph will always be left alone by the update request
# all other fields except these 4 will always be left alone by the update request
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None
@@ -353,9 +351,7 @@ class HybridCapable(abc.ABC):
hybrid_alpha: float,
time_decay_multiplier: float,
num_to_retrieve: int,
ranking_profile_type: QueryExpansionType,
offset: int = 0,
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
) -> list[InferenceChunkUncleaned]:
"""
Run hybrid search and return a list of inference chunks.

View File

@@ -85,25 +85,6 @@ schema DANSWER_CHUNK_NAME {
indexing: attribute
}
# Separate array fields for knowledge graph data
field kg_entities type weightedset<string> {
rank: filter
indexing: summary | attribute
attribute: fast-search
}
field kg_relationships type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field kg_terms type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute
@@ -195,7 +176,7 @@ schema DANSWER_CHUNK_NAME {
match-features: recency_bias
}
rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank {
rank-profile hybrid_searchVARIABLE_DIM inherits default, default_rank {
inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
@@ -211,75 +192,7 @@ schema DANSWER_CHUNK_NAME {
# First phase must be vector to allow hits that have no keyword matches
first-phase {
expression: query(title_content_ratio) * closeness(field, title_embedding) + (1 - query(title_content_ratio)) * closeness(field, embeddings)
}
# Weighted average between Vector Search and BM-25
global-phase {
expression {
(
# Weighted Vector Similarity Score
(
query(alpha) * (
(query(title_content_ratio) * normalize_linear(title_vector_score))
+
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
)
)
+
# Weighted Keyword Similarity Score
# Note: for the BM25 Title score, it requires decent stopword removal in the query
# This needs to be the case so there aren't irrelevant titles being normalized to a score of 1
(
(1 - query(alpha)) * (
(query(title_content_ratio) * normalize_linear(bm25(title)))
+
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
)
)
)
# Boost based on user feedback
* document_boost
# Decay factor based on time document was last updated
* recency_bias
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
rerank-count: 1000
}
match-features {
bm25(title)
bm25(content)
closeness(field, title_embedding)
closeness(field, embeddings)
document_boost
recency_bias
aggregated_chunk_boost
closest(embeddings)
}
}
rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank {
inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
function title_vector_score() {
expression {
# If no good matching titles, then it should use the context embeddings rather than having some
# irrelevant title have a vector score of 1. This way at least it will be the doc with the highest
# matching content score getting the full score
max(closeness(field, embeddings), closeness(field, title_embedding))
}
}
# First phase must be vector to allow hits that have no keyword matches
first-phase {
expression: query(title_content_ratio) * bm25(title) + (1 - query(title_content_ratio)) * bm25(content)
expression: closeness(field, embeddings)
}
# Weighted average between Vector Search and BM-25

View File

@@ -166,19 +166,18 @@ def _get_chunks_via_visit_api(
# build the list of fields to retrieve
field_set_list = (
[] if not field_names else [f"{field_name}" for field_name in field_names]
None
if not field_names
else [f"{index_name}:{field_name}" for field_name in field_names]
)
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
if (
field_set_list
and filters.access_control_list
and acl_fieldset_entry not in field_set_list
):
field_set_list.append(acl_fieldset_entry)
if field_set_list:
field_set = f"{index_name}:" + ",".join(field_set_list)
else:
field_set = None
field_set = ",".join(field_set_list) if field_set_list else None
# build filters
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"

View File

@@ -17,10 +17,8 @@ from uuid import UUID
import httpx # type: ignore
import requests # type: ignore
from pydantic import BaseModel
from retry import retry
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -30,7 +28,6 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
@@ -102,37 +99,6 @@ class _VespaUpdateRequest:
update_request: dict[str, dict]
class KGVespaChunkUpdateRequest(BaseModel):
document_id: str
chunk_id: int
url: str
update_request: dict[str, dict]
class KGUChunkUpdateRequest(BaseModel):
"""
Update KG fields for a single chunk of a document
"""
document_id: str
chunk_id: int
core_entity: str
entities: set[str] | None = None
relationships: set[str] | None = None
terms: set[str] | None = None
class KGUDocumentUpdateRequest(BaseModel):
"""
Update KG fields for a document
"""
document_id: str
entities: set[str]
relationships: set[str]
terms: set[str]
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -537,51 +503,6 @@ class VespaIndex(DocumentIndex):
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
@classmethod
def _apply_kg_chunk_updates_batched(
cls,
updates: list[KGVespaChunkUpdateRequest],
httpx_client: httpx.Client,
batch_size: int = BATCH_SIZE,
) -> None:
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
def _kg_update_chunk(
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
) -> httpx.Response:
# logger.debug(
# f"Updating KG with request to {update.url} with body {update.update_request}"
# )
return http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
)
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
httpx_client as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
executor.submit(
_kg_update_chunk,
update,
http_client,
): update.document_id
for update in update_batch
}
for future in concurrent.futures.as_completed(future_to_document_id):
res = future.result()
try:
res.raise_for_status()
except requests.HTTPError as e:
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
@@ -665,89 +586,6 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def kg_chunk_updates(
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
) -> None:
def _get_general_entity(specific_entity: str) -> str:
entity_type, entity_name = specific_entity.split(":")
if entity_type != "*":
return f"{entity_type}:*"
else:
return specific_entity
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
update_start = time.monotonic()
# Build the _VespaUpdateRequest objects
for kg_update_request in kg_update_requests:
kg_update_dict: dict[str, dict] = {"fields": {}}
implied_entities = set()
if kg_update_request.relationships is not None:
for kg_relationship in kg_update_request.relationships:
kg_relationship_split = kg_relationship.split("__")
if len(kg_relationship_split) == 3:
implied_entities.add(kg_relationship_split[0])
implied_entities.add(kg_relationship_split[2])
# Keep this for now in case we want to also add the general entities
# implied_entities.add(_get_general_entity(kg_relationship_split[0]))
# implied_entities.add(_get_general_entity(kg_relationship_split[2]))
kg_update_dict["fields"]["kg_relationships"] = {
"assign": {
kg_relationship: 1
for kg_relationship in kg_update_request.relationships
}
}
if kg_update_request.entities is not None or implied_entities:
if kg_update_request.entities is None:
kg_entities = implied_entities
else:
kg_entities = set(kg_update_request.entities)
kg_entities.update(implied_entities)
kg_update_dict["fields"]["kg_entities"] = {
"assign": {kg_entity: 1 for kg_entity in kg_entities}
}
if kg_update_request.terms is not None:
kg_update_dict["fields"]["kg_terms"] = {
"assign": {kg_term: 1 for kg_term in kg_update_request.terms}
}
if not kg_update_dict["fields"]:
logger.error("Update request received but nothing to update")
continue
doc_chunk_id = get_uuid_from_chunk_info(
document_id=kg_update_request.document_id,
chunk_id=kg_update_request.chunk_id,
tenant_id=tenant_id,
large_chunk_id=None,
)
processed_updates_requests.append(
KGVespaChunkUpdateRequest(
document_id=kg_update_request.document_id,
chunk_id=kg_update_request.chunk_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=kg_update_dict,
)
)
with self.httpx_client_context as httpx_client:
self._apply_kg_chunk_updates_batched(
processed_updates_requests, httpx_client
)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
)
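For orientation, a minimal sketch of the partial-update body this method assembles for one chunk before it is wrapped in a KGVespaChunkUpdateRequest; the entity, relationship, and term names are invented placeholders.
# Hypothetical "fields" payload mirroring the weighted-set assigns built above.
example_update_request = {
    "fields": {
        "kg_relationships": {
            "assign": {"ACCOUNT:Acme__relates_neutrally_to__FIREFLIES:call-42": 1}
        },
        "kg_entities": {"assign": {"ACCOUNT:Acme": 1, "FIREFLIES:call-42": 1}},
        "kg_terms": {"assign": {"renewal": 1}},
    }
}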
@retry(
tries=3,
delay=1,
@@ -962,14 +800,12 @@ class VespaIndex(DocumentIndex):
hybrid_alpha: float,
time_decay_multiplier: float,
num_to_retrieve: int,
ranking_profile_type: QueryExpansionType,
offset: int = 0,
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
@@ -981,11 +817,6 @@ class VespaIndex(DocumentIndex):
final_query = " ".join(final_keywords) if final_keywords else query
if ranking_profile_type == QueryExpansionType.KEYWORD:
ranking_profile = f"hybrid_search_keyword_base_{len(query_embedding)}"
else:
ranking_profile = f"hybrid_search_semantic_base_{len(query_embedding)}"
logger.debug(f"Query YQL: {yql}")
params: dict[str, str | int | float] = {
@@ -1001,7 +832,7 @@ class VespaIndex(DocumentIndex):
),
"hits": num_to_retrieve,
"offset": offset,
"ranking.profile": ranking_profile,
"ranking.profile": f"hybrid_search{len(query_embedding)}",
"timeout": VESPA_TIMEOUT,
}

View File

@@ -1,82 +0,0 @@
from pydantic import BaseModel
from retry import retry
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.utils.logger import setup_logger
# from backend.onyx.chat.process_message import get_inference_chunks
# from backend.onyx.document_index.vespa.index import VespaIndex
logger = setup_logger()
class KGChunkInfo(BaseModel):
kg_relationships: dict[str, int]
kg_entities: dict[str, int]
kg_terms: dict[str, int]
@retry(tries=3, delay=1, backoff=2)
def get_document_kg_info(
document_id: str,
index_name: str,
filters: IndexFilters | None = None,
) -> dict | None:
"""
Retrieve the KG fields (kg_relationships, kg_entities, kg_terms) for each chunk of a Vespa document by its document_id.
Args:
document_id: The unique identifier of the document.
index_name: The name of the Vespa index to query.
filters: Optional access control filters to apply.
Returns:
A mapping of chunk id to KGChunkInfo for the document's chunks.
"""
# Use the existing visit API infrastructure
kg_doc_info: dict[int, KGChunkInfo] = {}
document_chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
field_names=["kg_relationships", "kg_entities", "kg_terms"],
get_large_chunks=False,
)
for chunk_id, document_chunk in enumerate(document_chunks):
kg_chunk_info = KGChunkInfo(
kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
kg_entities=document_chunk["fields"].get("kg_entities", {}),
kg_terms=document_chunk["fields"].get("kg_terms", {}),
)
kg_doc_info[chunk_id] = kg_chunk_info # TODO: check the chunk id is correct!
return kg_doc_info
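A minimal usage sketch of the helper above; the document id and index name are placeholders, and an index carrying the KG fields is assumed to exist.
# Hypothetical call; the result maps chunk id -> KGChunkInfo.
kg_info_by_chunk = get_document_kg_info(
    document_id="FIREFLIES:example-call-id",
    index_name="danswer_chunk",
)
if kg_info_by_chunk:
    for chunk_id, chunk_info in kg_info_by_chunk.items():
        logger.info(f"chunk {chunk_id}: {sorted(chunk_info.kg_entities)}")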
@retry(tries=3, delay=1, backoff=2)
def update_kg_chunks_vespa_info(
kg_update_requests: list[KGUChunkUpdateRequest],
index_name: str,
tenant_id: str,
) -> None:
""" """
# Use the existing visit API infrastructure
vespa_index = VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=False,
multitenant=False,
httpx_client=None,
)
vespa_index.kg_chunk_updates(
kg_update_requests=kg_update_requests, tenant_id=tenant_id
)
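And a short sketch of pushing one KG update through update_kg_chunks_vespa_info; all identifiers here are made up for illustration.
# Hypothetical single-chunk update; values are placeholders.
example_request = KGUChunkUpdateRequest(
    document_id="FIREFLIES:example-call-id",
    chunk_id=0,
    core_entity="FIREFLIES:example-call-id",
    entities={"ACCOUNT:Acme", "FIREFLIES:example-call-id"},
    relationships={"ACCOUNT:Acme__relates_neutrally_to__FIREFLIES:example-call-id"},
    terms={"renewal"},
)
update_kg_chunks_vespa_info(
    kg_update_requests=[example_request],
    index_name="danswer_chunk",
    tenant_id="public",
)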

View File

@@ -5,6 +5,7 @@ from datetime import timezone
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
@@ -66,29 +67,6 @@ def build_vespa_filters(
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
def _build_kg_filter(
kg_entities: list[str] | None,
kg_relationships: list[str] | None,
kg_terms: list[str] | None,
) -> str:
if not kg_entities and not kg_relationships and not kg_terms:
return ""
filter_parts = []
# Process each filter type using the same pattern
for filter_type, values in [
("kg_entities", kg_entities),
("kg_relationships", kg_relationships),
("kg_terms", kg_terms),
]:
if values:
filter_parts.append(
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
)
return f"({' and '.join(filter_parts)}) and "
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
@@ -128,13 +106,6 @@ def build_vespa_filters(
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
# KG filter
filter_str += _build_kg_filter(
kg_entities=filters.kg_entities,
kg_relationships=filters.kg_relationships,
kg_terms=filters.kg_terms,
)
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]
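For reference, a minimal sketch of the YQL fragment the removed _build_kg_filter helper produced, assuming one entity filter and one term filter; the values are illustrative.
# Hypothetical reconstruction of the removed KG filter output.
kg_entities = ["ACCOUNT:Acme"]
kg_terms = ["renewal"]
filter_parts = [
    " and ".join(f'(kg_entities contains "{val}") ' for val in kg_entities),
    " and ".join(f'(kg_terms contains "{val}") ' for val in kg_terms),
]
kg_filter = f"({' and '.join(filter_parts)}) and "
# The fragment was appended to filter_str above before the trailing " and " was trimmed.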

View File

@@ -159,8 +159,6 @@ def load_files_from_zip(
zip_metadata = json.load(metadata_file)
if isinstance(zip_metadata, list):
# convert list of dicts to dict of dicts
# Use just the basename for matching since metadata may not include
# the full path within the ZIP file
zip_metadata = {d["filename"]: d for d in zip_metadata}
except json.JSONDecodeError:
logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
@@ -178,13 +176,7 @@ def load_files_from_zip(
continue
with zip_file.open(file_info.filename, "r") as subfile:
# Try to match by exact filename first
if file_info.filename in zip_metadata:
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
else:
# Then try matching by just the basename
basename = os.path.basename(file_info.filename)
yield file_info, subfile, zip_metadata.get(basename, {})
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
def _extract_onyx_metadata(line: str) -> dict | None:

File diff suppressed because it is too large

View File

@@ -1,229 +0,0 @@
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
from thefuzz import process # type: ignore
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entities_for_types
from onyx.db.relationships import get_relationships_for_entity_type_pairs
from onyx.kg.models import NormalizedEntities
from onyx.kg.models import NormalizedRelationships
from onyx.kg.models import NormalizedTerms
from onyx.kg.utils.embeddings import encode_string_batch
def _get_existing_normalized_entities(raw_entities: List[str]) -> List[str]:
"""
Get existing normalized entities from the database.
"""
entity_types = list(set([entity.split(":")[0] for entity in raw_entities]))
with get_session_with_current_tenant() as db_session:
entities = get_entities_for_types(db_session, entity_types)
return [entity.id_name for entity in entities]
def _get_existing_normalized_relationships(
raw_relationships: List[str],
) -> Dict[str, Dict[str, List[str]]]:
"""
Get existing normalized relationships from the database.
"""
relationship_type_map: Dict[str, Dict[str, List[str]]] = defaultdict(
lambda: defaultdict(list)
)
relationship_pairs = list(
set(
[
(
relationship.split("__")[0].split(":")[0],
relationship.split("__")[2].split(":")[0],
)
for relationship in raw_relationships
]
)
)
with get_session_with_current_tenant() as db_session:
relationships = get_relationships_for_entity_type_pairs(
db_session, relationship_pairs
)
for relationship in relationships:
relationship_type_map[relationship.source_entity_type_id_name][
relationship.target_entity_type_id_name
].append(relationship.id_name)
return relationship_type_map
def normalize_entities(raw_entities: List[str]) -> NormalizedEntities:
"""
Match each entity against a list of normalized entities using fuzzy matching.
Returns the best matching normalized entity for each input entity.
Args:
raw_entities: List of entity strings to normalize
Returns:
List of normalized entity strings
"""
# Assume this is your predefined list of normalized entities
norm_entities = _get_existing_normalized_entities(raw_entities)
normalized_results: List[str] = []
normalized_map: Dict[str, str | None] = {}
threshold = 80 # Adjust threshold as needed
for entity in raw_entities:
# Find the best match and its score from norm_entities
best_match, score = process.extractOne(entity, norm_entities)
if score >= threshold:
normalized_results.append(best_match)
normalized_map[entity] = best_match
else:
# If no good match found, keep original
normalized_map[entity] = None
return NormalizedEntities(
entities=normalized_results, entity_normalization_map=normalized_map
)
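A small sketch of the fuzzy-matching step in isolation, using the same thefuzz dependency; the candidate list and threshold are illustrative stand-ins for the database-backed values.
from thefuzz import process  # type: ignore

candidates = ["ACCOUNT:Acme Corp", "VENDOR:Onyx"]
best_match, score = process.extractOne("ACCOUNT:acme", candidates)
normalized = best_match if score >= 80 else None  # mirrors the threshold logic above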
def normalize_relationships(
raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
) -> NormalizedRelationships:
"""
Normalize relationships using entity mappings and relationship string matching.
Args:
raw_relationships: List of relationships in format "source__relation__target"
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
Returns:
NormalizedRelationships containing normalized relationships and mapping
"""
# Placeholder for normalized relationship structure
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
normalized_rels: List[str] = []
normalization_map: Dict[str, str | None] = {}
for raw_rel in raw_relationships:
# 1. Split and normalize entities
try:
source, rel_string, target = raw_rel.split("__")
except ValueError:
raise ValueError(f"Invalid relationship format: {raw_rel}")
# Check if entities are in normalization map and not None
norm_source = entity_normalization_map.get(source)
norm_target = entity_normalization_map.get(target)
if norm_source is None or norm_target is None:
normalization_map[raw_rel] = None
continue
# 2. Find candidate normalized relationships
candidate_rels = []
norm_source_type = norm_source.split(":")[0]
norm_target_type = norm_target.split(":")[0]
if (
norm_source_type in nor_relationships
and norm_target_type in nor_relationships[norm_source_type]
):
candidate_rels = [
rel.split("__")[1]
for rel in nor_relationships[norm_source_type][norm_target_type]
]
if not candidate_rels:
normalization_map[raw_rel] = None
continue
# 3. Encode and find best match
strings_to_encode = [rel_string] + candidate_rels
vectors = encode_string_batch(strings_to_encode)
# Get raw relation vector and candidate vectors
raw_vector = vectors[0]
candidate_vectors = vectors[1:]
# Calculate dot products
dot_products = np.dot(candidate_vectors, raw_vector)
best_match_idx = np.argmax(dot_products)
# Create normalized relationship
norm_rel = f"{norm_source}__{candidate_rels[best_match_idx]}__{norm_target}"
normalized_rels.append(norm_rel)
normalization_map[raw_rel] = norm_rel
return NormalizedRelationships(
relationships=normalized_rels, relationship_normalization_map=normalization_map
)
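A minimal sketch of the embedding-based ranking used above, with stand-in vectors in place of a live encode_string_batch call.
import numpy as np

# Rows stand in for encode_string_batch([rel_string, candidate_1, candidate_2]).
vectors = np.array([[0.9, 0.1], [0.8, 0.2], [0.1, 0.9]])
raw_vector, candidate_vectors = vectors[0], vectors[1:]
best_match_idx = int(np.argmax(np.dot(candidate_vectors, raw_vector)))  # -> 0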
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
"""
Normalize terms using semantic similarity matching.
Args:
raw_terms: List of terms to normalize
Returns:
NormalizedTerms containing normalized terms and mapping
"""
# # Placeholder for normalized terms - this would typically come from a predefined list
# normalized_term_list = [
# "algorithm",
# "database",
# "software",
# "programming",
# # ... other normalized terms ...
# ]
# normalized_terms: List[str] = []
# normalization_map: Dict[str, str | None] = {}
# if not raw_terms:
# return NormalizedTerms(terms=[], term_normalization_map={})
# # Encode all terms at once for efficiency
# strings_to_encode = raw_terms + normalized_term_list
# vectors = encode_string_batch(strings_to_encode)
# # Split vectors into query terms and candidate terms
# query_vectors = vectors[:len(raw_terms)]
# candidate_vectors = vectors[len(raw_terms):]
# # Calculate similarity for each term
# for i, term in enumerate(raw_terms):
# # Calculate dot products with all candidates
# similarities = np.dot(candidate_vectors, query_vectors[i])
# best_match_idx = np.argmax(similarities)
# best_match_score = similarities[best_match_idx]
# # Use a threshold to determine if the match is good enough
# if best_match_score > 0.7: # Adjust threshold as needed
# normalized_term = normalized_term_list[best_match_idx]
# normalized_terms.append(normalized_term)
# normalization_map[term] = normalized_term
# else:
# # If no good match found, keep original
# normalization_map[term] = None
# return NormalizedTerms(
# terms=normalized_terms,
# term_normalization_map=normalization_map
# )
return NormalizedTerms(
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
)

View File

@@ -1,191 +0,0 @@
from onyx.configs.kg_configs import KG_IGNORE_EMAIL_DOMAINS
from onyx.configs.kg_configs import KG_OWN_COMPANY
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_kg_entity_by_document
from onyx.kg.context_preparations_extraction.models import ContextPreparation
from onyx.kg.context_preparations_extraction.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.prompts.kg_prompts import FIREFLIES_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT
def prepare_llm_content_fireflies(chunk: KGChunkFormat) -> ContextPreparation:
"""
Fireflies - prepare the content for the LLM.
"""
document_id = chunk.document_id
primary_owners = chunk.primary_owners
secondary_owners = chunk.secondary_owners
content = chunk.content
chunk.title.capitalize()
implied_entities = set()
implied_relationships = set()
with get_session_with_current_tenant() as db_session:
core_document = get_kg_entity_by_document(db_session, document_id)
if core_document:
core_document_id_name = core_document.id_name
else:
core_document_id_name = f"FIREFLIES:{document_id}"
# Do we need this here?
implied_entities.add(f"VENDOR:{KG_OWN_COMPANY}")
implied_entities.add(f"{core_document_id_name}")
implied_entities.add("FIREFLIES:*")
implied_relationships.add(
f"VENDOR:{KG_OWN_COMPANY}__relates_neutrally_to__{core_document_id_name}"
)
company_participant_emails = set()
account_participant_emails = set()
for owner in primary_owners + secondary_owners:
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
kg_owner = kg_email_processing(owner)
if any(
domain.lower() in kg_owner.company.lower()
for domain in KG_IGNORE_EMAIL_DOMAINS
):
continue
if kg_owner.employee:
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
if kg_owner.name not in implied_entities:
generalized_target_entity = list(
generalize_entities([core_document_id_name])
)[0]
implied_entities.add(f"EMPLOYEE:{kg_owner.name}")
implied_relationships.add(
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{generalized_target_entity}"
)
implied_relationships.add(
f"EMPLOYEE:*__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"EMPLOYEE:*__relates_neutrally_to__{generalized_target_entity}"
)
if kg_owner.company not in implied_entities:
implied_entities.add(f"VENDOR:{kg_owner.company}")
implied_relationships.add(
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
)
else:
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
if kg_owner.company not in implied_entities:
implied_entities.add(f"ACCOUNT:{kg_owner.company}")
implied_entities.add("ACCOUNT:*")
implied_relationships.add(
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
)
implied_relationships.add(
f"ACCOUNT:*__relates_neutrally_to__{core_document_id_name}"
)
generalized_target_entity = list(
generalize_entities([core_document_id_name])
)[0]
implied_relationships.add(
f"ACCOUNT:*__relates_neutrally_to__{generalized_target_entity}"
)
implied_relationships.add(
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
)
participant_string = "\n - " + "\n - ".join(company_participant_emails)
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
llm_context = FIREFLIES_CHUNK_PREPROCESSING_PROMPT.format(
participant_string=participant_string,
account_participant_string=account_participant_string,
content=content,
)
return ContextPreparation(
llm_context=llm_context,
core_entity=core_document_id_name,
implied_entities=list(implied_entities),
implied_relationships=list(implied_relationships),
implied_terms=[],
)
def prepare_llm_document_content_fireflies(
document_classification_content: KGClassificationContent,
category_list: str,
category_definition_string: str,
) -> KGDocumentClassificationPrompt:
"""
Fireflies - prepare prompt for the LLM classification.
"""
prompt = FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT.format(
beginning_of_call_content=document_classification_content.classification_content,
category_list=category_list,
category_options=category_definition_string,
)
return KGDocumentClassificationPrompt(
llm_prompt=prompt,
)
def get_classification_content_from_fireflies_chunks(
first_num_classification_chunks: list[dict],
) -> str:
"""
Builds the classification content string from a list of Fireflies chunks.
"""
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
primary_owners = first_num_classification_chunks[0]["fields"]["primary_owners"]
secondary_owners = first_num_classification_chunks[0]["fields"]["secondary_owners"]
company_participant_emails = set()
account_participant_emails = set()
for owner in primary_owners + secondary_owners:
kg_owner = kg_email_processing(owner)
if any(
domain.lower() in kg_owner.company.lower()
for domain in KG_IGNORE_EMAIL_DOMAINS
):
continue
if kg_owner.employee:
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
else:
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
participant_string = "\n - " + "\n - ".join(company_participant_emails)
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
title_string = first_num_classification_chunks[0]["fields"]["title"]
content_string = "\n".join(
[
chunk_content["fields"]["content"]
for chunk_content in first_num_classification_chunks
]
)
classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
return classification_content

View File

@@ -1,21 +0,0 @@
from pydantic import BaseModel
class ContextPreparation(BaseModel):
"""
Context preparation format for the LLM KG extraction.
"""
llm_context: str
core_entity: str
implied_entities: list[str]
implied_relationships: list[str]
implied_terms: list[str]
class KGDocumentClassificationPrompt(BaseModel):
"""
Document classification prompt format for the LLM KG extraction.
"""
llm_prompt: str | None

View File

@@ -1,947 +0,0 @@
import json
from collections import defaultdict
from collections.abc import Callable
from typing import cast
from typing import Dict
from langchain_core.messages import HumanMessage
from onyx.db.connector import get_unprocessed_connector_ids
from onyx.db.document import get_document_updated_at
from onyx.db.document import get_unprocessed_kg_documents_for_connector
from onyx.db.document import update_document_kg_info
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import add_entity
from onyx.db.entities import get_entity_types
from onyx.db.relationships import add_or_increment_relationship
from onyx.db.relationships import add_relationship
from onyx.db.relationships import add_relationship_type
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import KGUDocumentUpdateRequest
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
from onyx.kg.models import ConnectorExtractionStats
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGBatchExtractionStats
from onyx.kg.models import KGChunkExtraction
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGChunkId
from onyx.kg.models import KGClassificationContent
from onyx.kg.models import KGClassificationDecisions
from onyx.kg.models import KGClassificationInstructionStrings
from onyx.kg.utils.chunk_preprocessing import prepare_llm_content
from onyx.kg.utils.chunk_preprocessing import prepare_llm_document_content
from onyx.kg.utils.formatting_utils import aggregate_kg_extractions
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import generalize_relationships
from onyx.kg.vespa.vespa_interactions import get_document_chunks_for_kg_processing
from onyx.kg.vespa.vespa_interactions import (
get_document_classification_content_for_kg_processing,
)
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import message_to_string
from onyx.prompts.kg_prompts import MASTER_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def _get_classification_instructions() -> Dict[str, KGClassificationInstructionStrings]:
"""
Prepare the classification instructions for each grounded source.
"""
classification_instructions_dict: Dict[str, KGClassificationInstructionStrings] = {}
with get_session_with_current_tenant() as db_session:
entity_types = get_entity_types(db_session, active=None)
for entity_type in entity_types:
grounded_source_name = entity_type.grounded_source_name
if grounded_source_name is None:
continue
classification_class_definitions = entity_type.classification_requirements
classification_options = ", ".join(classification_class_definitions.keys())
classification_instructions_dict[
grounded_source_name
] = KGClassificationInstructionStrings(
classification_options=classification_options,
classification_class_definitions=classification_class_definitions,
)
return classification_instructions_dict
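For orientation, a sketch of a single per-source entry this helper can return; the category names and descriptions are invented, but the nested shape matches how classification_class_definitions is consumed further below.
# Hypothetical entry keyed by a grounded source name such as "fireflies".
example_instructions = KGClassificationInstructionStrings(
    classification_options="SALES_CALL, SUPPORT_CALL",
    classification_class_definitions={
        "SALES_CALL": {
            "description": "Prospect or customer sales conversation",
            "extraction": True,
        },
        "SUPPORT_CALL": {
            "description": "Technical support conversation",
            "extraction": False,
        },
    },
)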
def get_entity_types_str(active: bool | None = None) -> str:
"""
Get the entity types from the database as a formatted string.
"""
with get_session_with_current_tenant() as db_session:
active_entity_types = get_entity_types(db_session, active)
entity_types_list = []
for entity_type in active_entity_types:
if entity_type.description:
entity_description = "\n - Description: " + entity_type.description
else:
entity_description = ""
if entity_type.ge_determine_instructions:
allowed_options = "\n - Allowed Options: " + ", ".join(
entity_type.ge_determine_instructions
)
else:
allowed_options = ""
entity_types_list.append(
entity_type.id_name + entity_description + allowed_options
)
return "\n".join(entity_types_list)
def get_relationship_types_str(active: bool | None = None) -> str:
"""
Get the relationship types from the database.
Args:
active: Filter by active status (True, False, or None for all)
Returns:
A string with all relationship types formatted as "source_type__relationship_type__target_type"
"""
from onyx.db.relationships import get_all_relationship_types
with get_session_with_current_tenant() as db_session:
relationship_types = get_all_relationship_types(db_session)
# Filter by active status if specified
if active is not None:
relationship_types = [
rt for rt in relationship_types if rt.active == active
]
relationship_types_list = []
for rel_type in relationship_types:
# Format as "source_type__relationship_type__target_type"
formatted_type = f"{rel_type.source_entity_type_id_name}__{rel_type.type}__{rel_type.target_entity_type_id_name}"
relationship_types_list.append(formatted_type)
return "\n".join(relationship_types_list)
def kg_extraction_initialization(tenant_id: str, num_chunks: int = 1000) -> None:
"""
This extraction will create a random sample of chunks to process in order to perform
clustering and topic modeling.
"""
logger.info(f"Starting kg extraction for tenant {tenant_id}")
def kg_extraction(
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
) -> list[ConnectorExtractionStats]:
"""
This extraction will try to extract from all chunks that have not been kg-processed yet.
Approach:
- Get all unprocessed connectors
- For each connector:
- Get all unprocessed documents
- Classify each document to select proper ones
- For each document:
- Get all chunks
- For each chunk:
- Extract entities, relationships, and terms
- make sure the generalized versions of each entity and relationship are also extracted
- Aggregate results as needed
- Update Vespa and postgres
"""
logger.info(f"Starting kg extraction for tenant {tenant_id}")
with get_session_with_current_tenant() as db_session:
connector_ids = get_unprocessed_connector_ids(db_session)
connector_extraction_stats: list[ConnectorExtractionStats] = []
document_kg_updates: Dict[str, KGUDocumentUpdateRequest] = {}
processing_chunks: list[KGChunkFormat] = []
carryover_chunks: list[KGChunkFormat] = []
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions] = []
document_classification_instructions = _get_classification_instructions()
for connector_id in connector_ids:
connector_failed_chunk_extractions: list[KGChunkId] = []
connector_succeeded_chunk_extractions: list[KGChunkId] = []
connector_aggregated_kg_extractions: KGAggregatedExtractions = (
KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
)
with get_session_with_current_tenant() as db_session:
unprocessed_documents = get_unprocessed_kg_documents_for_connector(
db_session,
connector_id,
)
# TODO: restricted for testing only
unprocessed_documents_list = list(unprocessed_documents)[:6]
document_classification_content_list = (
get_document_classification_content_for_kg_processing(
[
unprocessed_document.id
for unprocessed_document in unprocessed_documents_list
],
index_name,
batch_size=processing_chunk_batch_size,
)
)
classification_outcomes: list[tuple[bool, KGClassificationDecisions]] = []
for document_classification_content in document_classification_content_list:
classification_outcomes.extend(
_kg_document_classification(
document_classification_content,
document_classification_instructions,
)
)
documents_to_process = []
for document_to_process, document_classification_outcome in zip(
unprocessed_documents_list, classification_outcomes
):
if (
document_classification_outcome[0]
and document_classification_outcome[1].classification_decision
):
documents_to_process.append(document_to_process)
for document_to_process in documents_to_process:
formatted_chunk_batches = get_document_chunks_for_kg_processing(
document_to_process.id,
index_name,
batch_size=processing_chunk_batch_size,
)
formatted_chunk_batches_list = list(formatted_chunk_batches)
for formatted_chunk_batch in formatted_chunk_batches_list:
processing_chunks.extend(formatted_chunk_batch)
if len(processing_chunks) >= processing_chunk_batch_size:
carryover_chunks.extend(
processing_chunks[processing_chunk_batch_size:]
)
processing_chunks = processing_chunks[:processing_chunk_batch_size]
chunk_processing_batch_results = _kg_chunk_batch_extraction(
processing_chunks, index_name, tenant_id
)
# Consider removing the stats expressions here and writing to the db instead(?)
connector_failed_chunk_extractions.extend(
chunk_processing_batch_results.failed
)
connector_succeeded_chunk_extractions.extend(
chunk_processing_batch_results.succeeded
)
aggregated_batch_extractions = (
chunk_processing_batch_results.aggregated_kg_extractions
)
# Update grounded_entities_document_ids (replace values)
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
aggregated_batch_extractions.grounded_entities_document_ids
)
# Add to entity counts instead of replacing
for entity, count in aggregated_batch_extractions.entities.items():
connector_aggregated_kg_extractions.entities[entity] += count
# Add to relationship counts instead of replacing
for (
relationship,
count,
) in aggregated_batch_extractions.relationships.items():
connector_aggregated_kg_extractions.relationships[
relationship
] += count
# Add to term counts instead of replacing
for term, count in aggregated_batch_extractions.terms.items():
connector_aggregated_kg_extractions.terms[term] += count
connector_extraction_stats.append(
ConnectorExtractionStats(
connector_id=connector_id,
num_failed=len(connector_failed_chunk_extractions),
num_succeeded=len(connector_succeeded_chunk_extractions),
num_processed=len(processing_chunks),
)
)
processing_chunks = carryover_chunks.copy()
carryover_chunks = []
# processes remaining chunks
chunk_processing_batch_results = _kg_chunk_batch_extraction(
processing_chunks, index_name, tenant_id
)
# Consider removing the stats expressions here and writing to the db instead(?)
connector_failed_chunk_extractions.extend(chunk_processing_batch_results.failed)
connector_succeeded_chunk_extractions.extend(
chunk_processing_batch_results.succeeded
)
aggregated_batch_extractions = (
chunk_processing_batch_results.aggregated_kg_extractions
)
# Update grounded_entities_document_ids (replace values)
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
aggregated_batch_extractions.grounded_entities_document_ids
)
# Add to entity counts instead of replacing
for entity, count in aggregated_batch_extractions.entities.items():
connector_aggregated_kg_extractions.entities[entity] += count
# Add to relationship counts instead of replacing
for relationship, count in aggregated_batch_extractions.relationships.items():
connector_aggregated_kg_extractions.relationships[relationship] += count
# Add to term counts instead of replacing
for term, count in aggregated_batch_extractions.terms.items():
connector_aggregated_kg_extractions.terms[term] += count
connector_extraction_stats.append(
ConnectorExtractionStats(
connector_id=connector_id,
num_failed=len(connector_failed_chunk_extractions),
num_succeeded=len(connector_succeeded_chunk_extractions),
num_processed=len(processing_chunks),
)
)
processing_chunks = []
carryover_chunks = []
connector_aggregated_kg_extractions_list.append(
connector_aggregated_kg_extractions
)
# aggregate document updates
for (
processed_document
) in (
unprocessed_documents_list
): # This will need to change if we do not materialize docs
document_kg_updates[processed_document.id] = KGUDocumentUpdateRequest(
document_id=processed_document.id,
entities=set(),
relationships=set(),
terms=set(),
)
updated_chunk_batches = get_document_chunks_for_kg_processing(
processed_document.id,
index_name,
batch_size=processing_chunk_batch_size,
)
for updated_chunk_batch in updated_chunk_batches:
for updated_chunk in updated_chunk_batch:
chunk_entities = updated_chunk.entities
chunk_relationships = updated_chunk.relationships
chunk_terms = updated_chunk.terms
document_kg_updates[processed_document.id].entities.update(
chunk_entities
)
document_kg_updates[processed_document.id].relationships.update(
chunk_relationships
)
document_kg_updates[processed_document.id].terms.update(chunk_terms)
aggregated_kg_extractions = aggregate_kg_extractions(
connector_aggregated_kg_extractions_list
)
with get_session_with_current_tenant() as db_session:
tracked_entity_types = [
x.id_name for x in get_entity_types(db_session, active=None)
]
# Populate the KG database with the extracted entities, relationships, and terms
for (
entity,
extraction_count,
) in aggregated_kg_extractions.entities.items():
if len(entity.split(":")) != 2:
logger.error(
f"Invalid entity {entity} in aggregated_kg_extractions.entities"
)
continue
entity_type, entity_name = entity.split(":")
entity_type = entity_type.upper()
entity_name = entity_name.capitalize()
if entity_type not in tracked_entity_types:
continue
try:
with get_session_with_current_tenant() as db_session:
if (
entity
not in aggregated_kg_extractions.grounded_entities_document_ids
):
add_entity(
db_session=db_session,
entity_type=entity_type,
name=entity_name,
cluster_count=extraction_count,
)
else:
event_time = get_document_updated_at(
entity,
db_session,
)
add_entity(
db_session=db_session,
entity_type=entity_type,
name=entity_name,
cluster_count=extraction_count,
document_id=aggregated_kg_extractions.grounded_entities_document_ids[
entity
],
event_time=event_time,
)
db_session.commit()
except Exception as e:
logger.error(f"Error adding entity {entity} to the database: {e}")
relationship_type_counter: dict[str, int] = defaultdict(int)
for (
relationship,
extraction_count,
) in aggregated_kg_extractions.relationships.items():
relationship_split = relationship.split("__")
source_entity = relationship_split[0]
relationship_type = " ".join(relationship_split[1:-1]).replace("__", "_")
target_entity = relationship_split[-1]
source_entity_type = source_entity.split(":")[0]
target_entity_type = target_entity.split(":")[0]
if (
source_entity_type not in tracked_entity_types
or target_entity_type not in tracked_entity_types
):
continue
source_entity_general = f"{source_entity_type.upper()}"
target_entity_general = f"{target_entity_type.upper()}"
relationship_type_id_name = (
f"{source_entity_general}__{relationship_type.lower()}__"
f"{target_entity_general}"
)
relationship_type_counter[relationship_type_id_name] += extraction_count
for (
relationship_type_id_name,
extraction_count,
) in relationship_type_counter.items():
(
source_entity_type,
relationship_type,
target_entity_type,
) = relationship_type_id_name.split("__")
if (
source_entity_type not in tracked_entity_types
or target_entity_type not in tracked_entity_types
):
continue
try:
with get_session_with_current_tenant() as db_session:
try:
add_relationship_type(
db_session=db_session,
source_entity_type=source_entity_type.upper(),
relationship_type=relationship_type,
target_entity_type=target_entity_type.upper(),
definition=False,
extraction_count=extraction_count,
)
db_session.commit()
except Exception as e:
logger.error(
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
)
except Exception as e:
logger.error(
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
)
for (
relationship,
extraction_count,
) in aggregated_kg_extractions.relationships.items():
relationship_split = relationship.split("__")
source_entity = relationship_split[0]
relationship_type = (
" ".join(relationship_split[1:-1]).replace("__", " ").replace("_", " ")
)
target_entity = relationship_split[-1]
source_entity_type = source_entity.split(":")[0]
target_entity_type = target_entity.split(":")[0]
try:
with get_session_with_current_tenant() as db_session:
add_relationship(db_session, relationship, extraction_count)
db_session.commit()
except Exception as e:
logger.error(
f"Error adding relationship {relationship} to the database: {e}"
)
with get_session_with_current_tenant() as db_session:
add_or_increment_relationship(db_session, relationship)
db_session.commit()
# Populate the Documents table with the kg information for the documents
for document_id, document_kg_update in document_kg_updates.items():
with get_session_with_current_tenant() as db_session:
update_document_kg_info(
db_session,
document_id,
kg_processed=True,
kg_data={
"entities": list(document_kg_update.entities),
"relationships": list(document_kg_update.relationships),
"terms": list(document_kg_update.terms),
},
)
db_session.commit()
return connector_extraction_stats
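A minimal sketch of driving this entry point directly, assuming placeholder tenant and index names; in the real system the call would come from a background task rather than inline code.
# Hypothetical one-off run; "public" and "danswer_chunk" are placeholders.
stats = kg_extraction(tenant_id="public", index_name="danswer_chunk")
for connector_stats in stats:
    logger.info(
        f"Connector {connector_stats.connector_id}: "
        f"{connector_stats.num_succeeded} succeeded, {connector_stats.num_failed} failed"
    )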
def _kg_chunk_batch_extraction(
chunks: list[KGChunkFormat],
index_name: str,
tenant_id: str,
) -> KGBatchExtractionStats:
_, fast_llm = get_default_llms()
succeeded_chunk_id: list[KGChunkId] = []
failed_chunk_id: list[KGChunkId] = []
succeeded_chunk_extraction: list[KGChunkExtraction] = []
preformatted_prompt = MASTER_EXTRACTION_PROMPT.format(
entity_types=get_entity_types_str(active=True)
)
def process_single_chunk(
chunk: KGChunkFormat, preformatted_prompt: str
) -> tuple[bool, KGUChunkUpdateRequest]:
"""Process a single chunk and return success status and chunk ID."""
# For now, we're just processing the content
# TODO: Implement actual prompt application logic
llm_preprocessing = prepare_llm_content(chunk)
formatted_prompt = preformatted_prompt.replace(
"---content---", llm_preprocessing.llm_context
)
msg = [
HumanMessage(
content=formatted_prompt,
)
]
try:
logger.info(
f"LLM Extraction from chunk {chunk.chunk_id} from doc {chunk.document_id}"
)
raw_extraction_result = fast_llm.invoke(msg)
extraction_result = message_to_string(raw_extraction_result)
try:
cleaned_result = (
extraction_result.replace("```json", "").replace("```", "").strip()
)
parsed_result = json.loads(cleaned_result)
extracted_entities = parsed_result.get("entities", [])
extracted_relationships = parsed_result.get("relationships", [])
extracted_terms = parsed_result.get("terms", [])
implied_extracted_relationships = [
llm_preprocessing.core_entity
+ "__"
+ "relates_neutrally_to"
+ "__"
+ entity
for entity in extracted_entities
]
all_entities = set(
list(extracted_entities)
+ list(llm_preprocessing.implied_entities)
+ list(
generalize_entities(
extracted_entities + llm_preprocessing.implied_entities
)
)
)
logger.info(f"All entities: {all_entities}")
all_relationships = (
extracted_relationships
+ llm_preprocessing.implied_relationships
+ implied_extracted_relationships
)
all_relationships = set(
list(all_relationships)
+ list(generalize_relationships(all_relationships))
)
kg_updates = [
KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity=llm_preprocessing.core_entity,
entities=all_entities,
relationships=all_relationships,
terms=set(extracted_terms),
),
]
update_kg_chunks_vespa_info(
kg_update_requests=kg_updates,
index_name=index_name,
tenant_id=tenant_id,
)
logger.info(
f"KG updated: {chunk.chunk_id} from doc {chunk.document_id}"
)
return True, kg_updates[0] # only single chunk
except json.JSONDecodeError as e:
logger.error(
f"Invalid JSON format for extraction of chunk {chunk.chunk_id} \
from doc {chunk.document_id}: {str(e)}"
)
logger.error(f"Raw output: {extraction_result}")
return False, KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity=llm_preprocessing.core_entity,
entities=set(),
relationships=set(),
terms=set(),
)
except Exception as e:
logger.error(
f"Failed to process chunk {chunk.chunk_id} from doc {chunk.document_id}: {str(e)}"
)
return False, KGUChunkUpdateRequest(
document_id=chunk.document_id,
chunk_id=chunk.chunk_id,
core_entity="",
entities=set(),
relationships=set(),
terms=set(),
)
# Assume for prototype: use_threads = True. TODO: Make thread safe!
functions_with_args: list[tuple[Callable, tuple]] = [
(process_single_chunk, (chunk, preformatted_prompt)) for chunk in chunks
]
logger.debug("Running KG extraction on chunks in parallel")
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
# Sort results into succeeded and failed
for success, chunk_results in results:
if success:
succeeded_chunk_id.append(
KGChunkId(
document_id=chunk_results.document_id,
chunk_id=chunk_results.chunk_id,
)
)
succeeded_chunk_extraction.append(chunk_results)
else:
failed_chunk_id.append(
KGChunkId(
document_id=chunk_results.document_id,
chunk_id=chunk_results.chunk_id,
)
)
# Collect data for postgres later on
aggregated_kg_extractions = KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
for chunk_result in succeeded_chunk_extraction:
aggregated_kg_extractions.grounded_entities_document_ids[
chunk_result.core_entity
] = chunk_result.document_id
mentioned_chunk_entities: set[str] = set()
for relationship in chunk_result.relationships:
relationship_split = relationship.split("__")
if len(relationship_split) == 3:
if relationship_split[0] not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[relationship_split[0]] += 1
mentioned_chunk_entities.add(relationship_split[0])
if relationship_split[2] not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[relationship_split[2]] += 1
mentioned_chunk_entities.add(relationship_split[2])
aggregated_kg_extractions.relationships[relationship] += 1
for kg_entity in chunk_result.entities:
if kg_entity not in mentioned_chunk_entities:
aggregated_kg_extractions.entities[kg_entity] += 1
mentioned_chunk_entities.add(kg_entity)
for kg_term in chunk_result.terms:
aggregated_kg_extractions.terms[kg_term] += 1
return KGBatchExtractionStats(
connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
succeeded=succeeded_chunk_id,
failed=failed_chunk_id,
aggregated_kg_extractions=aggregated_kg_extractions,
)
def _kg_document_classification(
document_classification_content_list: list[KGClassificationContent],
classification_instructions: dict[str, KGClassificationInstructionStrings],
) -> list[tuple[bool, KGClassificationDecisions]]:
primary_llm, fast_llm = get_default_llms()
def classify_single_document(
document_classification_content: KGClassificationContent,
classification_instructions: dict[str, KGClassificationInstructionStrings],
) -> tuple[bool, KGClassificationDecisions]:
"""Classify a single document whether it should be kg-processed or not"""
source = document_classification_content.source_type
document_id = document_classification_content.document_id
if source not in classification_instructions:
logger.info(
f"Source {source} did not have kg classification instructions. No content analysis."
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
classification_prompt = prepare_llm_document_content(
document_classification_content,
category_list=classification_instructions[source].classification_options,
category_definitions=classification_instructions[
source
].classification_class_definitions,
)
if classification_prompt.llm_prompt is None:
logger.info(
f"Source {source} did not have kg document classification instructions. No content analysis."
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
msg = [
HumanMessage(
content=classification_prompt.llm_prompt,
)
]
try:
logger.info(
f"LLM Classification from document {document_classification_content.document_id}"
)
raw_classification_result = primary_llm.invoke(msg)
classification_result = (
message_to_string(raw_classification_result)
.replace("```json", "")
.replace("```", "")
.strip()
)
classification_class = classification_result.split("CATEGORY:")[1].strip()
if (
classification_class
in classification_instructions[source].classification_class_definitions
):
extraction_decision = cast(
bool,
classification_instructions[
source
].classification_class_definitions[classification_class][
"extraction"
],
)
else:
extraction_decision = False
return True, KGClassificationDecisions(
document_id=document_id,
classification_decision=extraction_decision,
classification_class=classification_class,
)
except Exception as e:
logger.error(
f"Failed to classify document {document_classification_content.document_id}: {str(e)}"
)
return False, KGClassificationDecisions(
document_id=document_id,
classification_decision=False,
classification_class=None,
)
# Assume for prototype: use_threads = True. TODO: Make thread safe!
functions_with_args: list[tuple[Callable, tuple]] = [
(
classify_single_document,
(document_classification_content, classification_instructions),
)
for document_classification_content in document_classification_content_list
]
logger.debug("Running KG classification on documents in parallel")
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
return results
# logger.debug("Running KG extraction on chunks in parallel")
# results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
# # Sort results into succeeded and failed
# for success, chunk_results in results:
# if success:
# succeeded_chunk_id.append(
# KGChunkId(
# document_id=chunk_results.document_id,
# chunk_id=chunk_results.chunk_id,
# )
# )
# succeeded_chunk_extraction.append(chunk_results)
# else:
# failed_chunk_id.append(
# KGChunkId(
# document_id=chunk_results.document_id,
# chunk_id=chunk_results.chunk_id,
# )
# )
# # Collect data for postgres later on
# aggregated_kg_extractions = KGAggregatedExtractions(
# grounded_entities_document_ids=defaultdict(str),
# entities=defaultdict(int),
# relationships=defaultdict(int),
# terms=defaultdict(int),
# )
# for chunk_result in succeeded_chunk_extraction:
# aggregated_kg_extractions.grounded_entities_document_ids[
# chunk_result.core_entity
# ] = chunk_result.document_id
# mentioned_chunk_entities: set[str] = set()
# for relationship in chunk_result.relationships:
# relationship_split = relationship.split("__")
# if len(relationship_split) == 3:
# if relationship_split[0] not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[relationship_split[0]] += 1
# mentioned_chunk_entities.add(relationship_split[0])
# if relationship_split[2] not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[relationship_split[2]] += 1
# mentioned_chunk_entities.add(relationship_split[2])
# aggregated_kg_extractions.relationships[relationship] += 1
# for kg_entity in chunk_result.entities:
# if kg_entity not in mentioned_chunk_entities:
# aggregated_kg_extractions.entities[kg_entity] += 1
# mentioned_chunk_entities.add(kg_entity)
# for kg_term in chunk_result.terms:
# aggregated_kg_extractions.terms[kg_term] += 1
# return KGBatchExtractionStats(
# connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
# succeeded=succeeded_chunk_id,
# failed=failed_chunk_id,
# aggregated_kg_extractions=aggregated_kg_extractions,
# )
def _kg_connector_extraction(
connector_id: str,
tenant_id: str,
) -> None:
logger.info(
f"Starting kg extraction for connector {connector_id} for tenant {tenant_id}"
)
# - grab kg type data from postgres
# - construct prompt
# find all documents for the connector that have not been kg-processed
# - loop:
# - grab a number of chunks from vespa
# - convert them into the KGChunk format
# - run the extractions in parallel
# - save the results
# - mark chunks as processed
# - update the connector status
#

View File

@@ -1,86 +0,0 @@
from onyx.utils.logger import setup_logger
logger = setup_logger()
"""
def kg_extraction_section(
kg_chunks: list[KGChunk],
llm: LLM,
max_workers: int = 10,
) -> list[KGChunkExtraction]:
def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
metadata_str = "\nMetadata:\n"
for key, value in metadata.items():
value_str = ", ".join(value) if isinstance(value, list) else value
metadata_str += f"{key} - {value_str}\n"
return metadata_str
def _get_usefulness_messages() -> list[dict[str, str]]:
metadata_str = _get_metadata_str(metadata) if metadata else ""
messages = [
{
"role": "user",
"content": SECTION_FILTER_PROMPT.format(
title=title.replace("\n", " "),
chunk_text=section_content,
user_query=query,
optional_metadata=metadata_str,
),
},
]
return messages
def _extract_usefulness(model_output: str) -> bool:
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
return False
return True
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
logger.debug(model_output)
return _extract_usefulness(model_output)
"""
"""
def llm_batch_eval_sections(
query: str,
section_contents: list[str],
llm: LLM,
titles: list[str],
metadata_list: list[dict[str, str | list[str]]],
use_threads: bool = True,
) -> list[bool]:
if DISABLE_LLM_DOC_RELEVANCE:
raise RuntimeError(
"LLM Doc Relevance is globally disabled, "
"this should have been caught upstream."
)
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_section, (query, section_content, llm, title, metadata))
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
logger.debug(
"Running LLM usefulness eval in parallel (following logging may be out of order)"
)
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# In case of failure/timeout, don't throw out the section
return [True if item is None else item for item in parallel_results]
else:
return [
llm_eval_section(query, section_content, llm, title, metadata)
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
"""

View File

@@ -1,99 +0,0 @@
from collections import defaultdict
from typing import Dict
from pydantic import BaseModel
class KGChunkFormat(BaseModel):
connector_id: int | None = None
document_id: str
chunk_id: int
title: str
content: str
primary_owners: list[str]
secondary_owners: list[str]
source_type: str
metadata: dict[str, str | list[str]] | None = None
entities: Dict[str, int] = {}
relationships: Dict[str, int] = {}
terms: Dict[str, int] = {}
class KGChunkExtraction(BaseModel):
connector_id: int
document_id: str
chunk_id: int
core_entity: str
entities: list[str]
relationships: list[str]
terms: list[str]
class KGChunkId(BaseModel):
connector_id: int | None = None
document_id: str
chunk_id: int
class KGAggregatedExtractions(BaseModel):
grounded_entities_document_ids: defaultdict[str, str]
entities: defaultdict[str, int]
relationships: defaultdict[str, int]
terms: defaultdict[str, int]
class KGBatchExtractionStats(BaseModel):
connector_id: int | None = None
succeeded: list[KGChunkId]
failed: list[KGChunkId]
aggregated_kg_extractions: KGAggregatedExtractions
class ConnectorExtractionStats(BaseModel):
connector_id: int
num_succeeded: int
num_failed: int
num_processed: int
class KGPerson(BaseModel):
name: str
company: str
employee: bool
class NormalizedEntities(BaseModel):
entities: list[str]
entity_normalization_map: dict[str, str | None]
class NormalizedRelationships(BaseModel):
relationships: list[str]
relationship_normalization_map: dict[str, str | None]
class NormalizedTerms(BaseModel):
terms: list[str]
term_normalization_map: dict[str, str | None]
class KGClassificationContent(BaseModel):
document_id: str
classification_content: str
source_type: str
class KGClassificationDecisions(BaseModel):
document_id: str
classification_decision: bool
classification_class: str | None
class KGClassificationRule(BaseModel):
description: str
extraction: bool
class KGClassificationInstructionStrings(BaseModel):
classification_options: str
classification_class_definitions: dict[str, Dict[str, str | bool]]
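A short sketch of how the defaultdict-backed aggregation model above is populated elsewhere in this change; the entity, relationship, and term names are invented.
from collections import defaultdict

aggregated = KGAggregatedExtractions(
    grounded_entities_document_ids=defaultdict(str),
    entities=defaultdict(int),
    relationships=defaultdict(int),
    terms=defaultdict(int),
)
aggregated.entities["ACCOUNT:Acme"] += 1
aggregated.relationships["ACCOUNT:Acme__relates_neutrally_to__FIREFLIES:call-42"] += 1
aggregated.terms["renewal"] += 1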

View File

@@ -1,55 +0,0 @@
from typing import Dict
from onyx.kg.context_preparations_extraction.fireflies import (
prepare_llm_content_fireflies,
)
from onyx.kg.context_preparations_extraction.fireflies import (
prepare_llm_document_content_fireflies,
)
from onyx.kg.context_preparations_extraction.models import ContextPreparation
from onyx.kg.context_preparations_extraction.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
def prepare_llm_content(chunk: KGChunkFormat) -> ContextPreparation:
"""
Prepare the content for the LLM.
"""
if chunk.source_type == "fireflies":
return prepare_llm_content_fireflies(chunk)
else:
return ContextPreparation(
llm_context=chunk.content,
core_entity="",
implied_entities=[],
implied_relationships=[],
implied_terms=[],
)
def prepare_llm_document_content(
document_classification_content: KGClassificationContent,
category_list: str,
category_definitions: dict[str, Dict[str, str | bool]],
) -> KGDocumentClassificationPrompt:
"""
Prepare the content for the extraction classification.
"""
category_definition_string = ""
for category, category_data in category_definitions.items():
category_definition_string += f"{category}: {category_data['description']}\n"
if document_classification_content.source_type == "fireflies":
return prepare_llm_document_content_fireflies(
document_classification_content, category_list, category_definition_string
)
else:
return KGDocumentClassificationPrompt(
llm_prompt=None,
)
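An illustrative sketch of the category definition string assembled above, using two made-up categories.
# Hypothetical category definitions as provided by the classification instructions.
category_definitions = {
    "SALES_CALL": {"description": "Prospect or customer sales conversation", "extraction": True},
    "SUPPORT_CALL": {"description": "Technical support conversation", "extraction": False},
}
category_definition_string = "".join(
    f"{category}: {data['description']}\n"
    for category, data in category_definitions.items()
)
# -> "SALES_CALL: Prospect or customer sales conversation\nSUPPORT_CALL: Technical support conversation\n"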

View File

@@ -1,41 +0,0 @@
from sqlalchemy import select
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_unprocessed_connector_ids(tenant_id: str) -> list[int]:
"""
Retrieves a list of connector IDs that have not been KG processed for a given tenant.
Args:
tenant_id (str): The ID of the tenant to check for unprocessed connectors
Returns:
list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
"""
try:
with get_session_with_current_tenant() as db_session:
# Find connectors that:
# 1. Have KG extraction enabled
# 2. Have documents that haven't been KG processed
stmt = (
select(Connector.id)
.distinct()
.join(DocumentByConnectorCredentialPair)
.where(
Connector.kg_extraction_enabled,
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(False),
)
)
result = db_session.execute(stmt)
return [row[0] for row in result.fetchall()]
except Exception as e:
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
return []

View File

@@ -1,23 +0,0 @@
from typing import List
import numpy as np
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
def encode_string_batch(strings: List[str]) -> np.ndarray:
with get_session_with_current_tenant() as db_session:
current_search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=current_search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
# Get embeddings while session is still open
embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
return np.array(embedding)
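A usage sketch, assuming a reachable model server and configured search settings; the resulting shape depends on the active embedding model:
# Hypothetical call: embed two short strings and inspect the resulting matrix.
embeddings = encode_string_batch(["knowledge graph", "entity extraction"])
print(embeddings.shape)  # e.g. (2, <embedding_dim>) for the configured model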

View File

@@ -1,123 +0,0 @@
from collections import defaultdict
from onyx.configs.kg_configs import KG_OWN_COMPANY
from onyx.configs.kg_configs import KG_OWN_EMAIL_DOMAINS
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGPerson
def format_entity(entity: str) -> str:
if len(entity.split(":")) == 2:
entity_type, entity_name = entity.split(":")
return f"{entity_type.upper()}:{entity_name.title()}"
else:
return entity
def format_relationship(relationship: str) -> str:
source_node, relationship_type, target_node = relationship.split("__")
return (
f"{format_entity(source_node)}__"
f"{relationship_type.lower()}__"
f"{format_entity(target_node)}"
)
def format_relationship_type(relationship_type: str) -> str:
source_node_type, relationship_type, target_node_type = relationship_type.split(
"__"
)
return (
f"{source_node_type.upper()}__"
f"{relationship_type.lower()}__"
f"{target_node_type.upper()}"
)
def generate_relationship_type(relationship: str) -> str:
source_node, relationship_type, target_node = relationship.split("__")
return (
f"{source_node.split(':')[0].upper()}__"
f"{relationship_type.lower()}__"
f"{target_node.split(':')[0].upper()}"
)
def aggregate_kg_extractions(
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
) -> KGAggregatedExtractions:
aggregated_kg_extractions = KGAggregatedExtractions(
grounded_entities_document_ids=defaultdict(str),
entities=defaultdict(int),
relationships=defaultdict(int),
terms=defaultdict(int),
)
for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
for (
grounded_entity,
document_id,
) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
aggregated_kg_extractions.grounded_entities_document_ids[
grounded_entity
] = document_id
for entity, count in connector_aggregated_kg_extractions.entities.items():
aggregated_kg_extractions.entities[entity] += count
for (
relationship,
count,
) in connector_aggregated_kg_extractions.relationships.items():
aggregated_kg_extractions.relationships[relationship] += count
for term, count in connector_aggregated_kg_extractions.terms.items():
aggregated_kg_extractions.terms[term] += count
return aggregated_kg_extractions
def kg_email_processing(email: str) -> KGPerson:
"""
Split an email address into a KGPerson with a name, a company, and an employee flag.
"""
name, company_domain = email.split("@")
assert isinstance(KG_OWN_EMAIL_DOMAINS, list)
employee = any(domain in company_domain for domain in KG_OWN_EMAIL_DOMAINS)
if employee:
company = KG_OWN_COMPANY
else:
company = company_domain.capitalize()
return KGPerson(name=name, company=company, employee=employee)
def generalize_entities(entities: list[str]) -> set[str]:
"""
Generalize entities to their entity-type wildcard (e.g. "account:acme" -> "account:*").
"""
return set([f"{entity.split(':')[0]}:*" for entity in entities])
def generalize_relationships(relationships: list[str]) -> set[str]:
"""
Generalize relationships by wildcarding the source entity, the target entity, and both.
"""
generalized_relationships: set[str] = set()
for relationship in relationships:
assert (
len(relationship.split("__")) == 3
), "Relationship is not in the correct format"
source_entity, relationship_type, target_entity = relationship.split("__")
generalized_source_entity = list(generalize_entities([source_entity]))[0]
generalized_target_entity = list(generalize_entities([target_entity]))[0]
generalized_relationships.add(
f"{generalized_source_entity}__{relationship_type}__{target_entity}"
)
generalized_relationships.add(
f"{source_entity}__{relationship_type}__{generalized_target_entity}"
)
generalized_relationships.add(
f"{generalized_source_entity}__{relationship_type}__{generalized_target_entity}"
)
return generalized_relationships
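A small sketch of the string conventions these helpers assume ("type:name" for entities, "source__relationship__target" for relationships); the example values are invented:
print(format_entity("account:acme corp"))
# -> "ACCOUNT:Acme Corp"
print(generate_relationship_type("account:acme__employs__person:jane"))
# -> "ACCOUNT__employs__PERSON"
print(generalize_entities(["account:acme", "person:jane"]))
# -> {"account:*", "person:*"} (a set, so ordering may vary)
print(generalize_relationships(["account:acme__employs__person:jane"]))
# -> three generalized variants: source wildcarded, target wildcarded, and both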

View File

@@ -1,179 +0,0 @@
import json
from collections.abc import Generator
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.kg.context_preparations_extraction.fireflies import (
get_classification_content_from_fireflies_chunks,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
def get_document_classification_content_for_kg_processing(
document_ids: list[str],
index_name: str,
batch_size: int = 8,
num_classification_chunks: int = 3,
) -> Generator[list[KGClassificationContent], None, None]:
"""
Generates the content used for initial classification of a document from
the first num_classification_chunks chunks.
"""
classification_content_list: list[KGClassificationContent] = []
for i in range(0, len(document_ids), batch_size):
batch_document_ids = document_ids[i : i + batch_size]
for document_id in batch_document_ids:
# ... existing code for getting chunks and processing ...
first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(
document_id=document_id,
max_chunk_ind=num_classification_chunks - 1,
min_chunk_ind=0,
),
index_name=index_name,
filters=IndexFilters(access_control_list=None),
field_names=[
"document_id",
"chunk_id",
"title",
"content",
"metadata",
"source_type",
"primary_owners",
"secondary_owners",
],
get_large_chunks=False,
)
first_num_classification_chunks = sorted(
first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
)[:num_classification_chunks]
classification_content = _get_classification_content_from_chunks(
first_num_classification_chunks
)
classification_content_list.append(
KGClassificationContent(
document_id=document_id,
classification_content=classification_content,
source_type=first_num_classification_chunks[0]["fields"][
"source_type"
],
)
)
# Yield the batch of classification content
if classification_content_list:
yield classification_content_list
classification_content_list = []
# Yield any remaining items
if classification_content_list:
yield classification_content_list
def get_document_chunks_for_kg_processing(
document_id: str,
index_name: str,
batch_size: int = 8,
) -> Generator[list[KGChunkFormat], None, None]:
"""
Retrieves chunks from Vespa for the given document IDs and converts them to KGChunks.
Args:
document_ids (list[str]): List of document IDs to fetch chunks for
index_name (str): Name of the Vespa index
tenant_id (str): ID of the tenant
Yields:
list[KGChunk]: Batches of chunks ready for KG processing
"""
current_batch: list[KGChunkFormat] = []
# get all chunks for the document
chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=IndexFilters(access_control_list=None),
field_names=[
"document_id",
"chunk_id",
"title",
"content",
"metadata",
"primary_owners",
"secondary_owners",
"source_type",
"kg_entities",
"kg_relationships",
"kg_terms",
],
get_large_chunks=False,
)
# Convert Vespa chunks to KGChunks
# kg_chunks: list[KGChunkFormat] = []
for i, chunk in enumerate(chunks):
fields = chunk["fields"]
if isinstance(fields.get("metadata", {}), str):
fields["metadata"] = json.loads(fields["metadata"])
current_batch.append(
KGChunkFormat(
connector_id=None, # We may need to adjust this
document_id=fields.get("document_id"),
chunk_id=fields.get("chunk_id"),
primary_owners=fields.get("primary_owners", []),
secondary_owners=fields.get("secondary_owners", []),
source_type=fields.get("source_type", ""),
title=fields.get("title", ""),
content=fields.get("content", ""),
metadata=fields.get("metadata", {}),
entities=fields.get("kg_entities", {}),
relationships=fields.get("kg_relationships", {}),
terms=fields.get("kg_terms", {}),
)
)
if len(current_batch) >= batch_size:
yield current_batch
current_batch = []
# Yield any remaining chunks
if current_batch:
yield current_batch
def _get_classification_content_from_chunks(
first_num_classification_chunks: list[dict],
) -> str:
"""
Builds the classification content string from the first chunks of a document.
"""
source_type = first_num_classification_chunks[0]["fields"]["source_type"]
if source_type == "fireflies":
classification_content = get_classification_content_from_fireflies_chunks(
first_num_classification_chunks
)
else:
classification_content = (
first_num_classification_chunks[0]["fields"]["title"]
+ "\n"
+ "\n".join(
[
chunk_content["fields"]["content"]
for chunk_content in first_num_classification_chunks
]
)
)
return classification_content
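A hypothetical driver loop for the chunk generator above; the index name is illustrative and would normally come from the tenant's search settings:
# Iterate over batches of KG-ready chunks for a single document.
for kg_chunk_batch in get_document_chunks_for_kg_processing(
    document_id="doc-123",
    index_name="danswer_chunk",
    batch_size=8,
):
    for kg_chunk in kg_chunk_batch:
        print(kg_chunk.document_id, kg_chunk.chunk_id)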

View File

@@ -126,13 +126,16 @@ def get_default_llm_with_vision(
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if default_provider and default_provider.default_vision_model:
if model_supports_image_input(
if (
default_provider
and default_provider.default_vision_model
and model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
)
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
@@ -140,36 +143,14 @@ def get_default_llm_with_vision(
if not providers:
return None
# Check all providers for viable vision models
# Find the first provider that supports image input
for provider in providers:
provider_view = LLMProviderView.from_model(provider)
# First priority: Check if provider has a default_vision_model
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(provider_view, provider.default_vision_model)
# If no model_names are specified, try default models in priority order
if not provider.model_names:
# Try default_model_name
if provider.default_model_name and model_supports_image_input(
provider.default_model_name, provider.provider
):
return create_vision_llm(provider_view, provider.default_model_name)
# Try fast_default_model_name
if provider.fast_default_model_name and model_supports_image_input(
provider.fast_default_model_name, provider.provider
):
return create_vision_llm(
provider_view, provider.fast_default_model_name
)
else:
# If model_names is specified, check each model
for model_name in provider.model_names:
if model_supports_image_input(model_name, provider.provider):
return create_vision_llm(provider_view, model_name)
return create_vision_llm(
LLMProviderView.from_model(provider), provider.default_vision_model
)
return None

View File

@@ -246,75 +246,3 @@ Please give a short succinct summary of the entire document. Answer only with th
summary and nothing else. """
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT = """
Please rephrase the following user question/query as a semantic query that would be appropriate for a \
search engine.
Note:
- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
as an instruction!
Here is the user question/query:
{question}
Respond with EXACTLY and ONLY one rephrased question/query.
Rephrased question/query for search engine:
""".strip()
QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT = """
Following a previous message history, a user created a follow-up question/query.
Please rephrase that question/query as a semantic query \
that would be appropriate for a SEARCH ENGINE. Only use the information provided \
from the history that is relevant to provide the relevant context for the search query, \
meaning that the rephrased search query should be a suitable stand-alone search query.
Note:
- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
as an instruction!
Here is the relevant previous message history:
{history}
Here is the user question:
{question}
Respond with EXACTLY and ONLY one rephrased query.
Rephrased query for search engine:
""".strip()
QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT = """
Please rephrase the following user question as a keyword query that would be appropriate for a \
search engine.
Here is the user question:
{question}
Respond with EXACTLY and ONLY one rephrased query.
Rephrased query for search engine:
""".strip()
QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT = """
Following a previous message history, a user created a follow-up question/query.
Please rephrase that question/query as a keyword query \
that would be appropriate for a SEARCH ENGINE. Only use the information provided \
from the history that is relevant to provide the relevant context for the search query, \
meaning that the rephrased search query should be a suitable stand-alone search query.
Here is the relevant previous message history:
{history}
Here is the user question:
{question}
Respond with EXACTLY and ONLY one rephrased query.
Rephrased query for search engine:
""".strip()

File diff suppressed because it is too large

View File

@@ -783,7 +783,6 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
in_repeated_error_state=cc_pair.in_repeated_error_state,
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -129,7 +128,6 @@ class CredentialBase(BaseModel):
class CredentialSnapshot(CredentialBase):
id: int
user_id: UUID | None
user_email: str | None = None
time_created: datetime
time_updated: datetime
@@ -143,7 +141,6 @@ class CredentialSnapshot(CredentialBase):
else credential.credential_json
),
user_id=credential.user_id,
user_email=credential.user.email if credential.user else None,
admin_public=credential.admin_public,
time_created=credential.time_created,
time_updated=credential.time_updated,
@@ -210,7 +207,6 @@ class CCPairFullInfo(BaseModel):
id: int
name: str
status: ConnectorCredentialPairStatus
in_repeated_error_state: bool
num_docs_indexed: int
connector: ConnectorSnapshot
credential: CredentialSnapshot
@@ -224,13 +220,6 @@ class CCPairFullInfo(BaseModel):
creator: UUID | None
creator_email: str | None
# information on syncing/indexing
last_indexed: datetime | None
last_pruned: datetime | None
last_permission_sync: datetime | None
overall_indexing_speed: float | None
latest_checkpoint_description: str | None
@classmethod
def from_models(
cls,
@@ -248,8 +237,7 @@ class CCPairFullInfo(BaseModel):
# there is a mismatch between these two numbers which may confuse users.
last_indexing_status = last_index_attempt.status if last_index_attempt else None
if (
# only need to do this if the last indexing attempt is still in progress
last_indexing_status == IndexingStatus.IN_PROGRESS
last_indexing_status == IndexingStatus.SUCCESS
and number_of_index_attempts == 1
and last_index_attempt
and last_index_attempt.new_docs_indexed
@@ -258,18 +246,10 @@ class CCPairFullInfo(BaseModel):
last_index_attempt.new_docs_indexed if last_index_attempt else 0
)
overall_indexing_speed = num_docs_indexed / (
(
datetime.now(tz=timezone.utc) - cc_pair_model.connector.time_created
).total_seconds()
/ 60
)
return cls(
id=cc_pair_model.id,
name=cc_pair_model.name,
status=cc_pair_model.status,
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector
@@ -288,15 +268,6 @@ class CCPairFullInfo(BaseModel):
creator_email=(
cc_pair_model.creator.email if cc_pair_model.creator else None
),
last_indexed=(
last_index_attempt.time_started if last_index_attempt else None
),
last_pruned=cc_pair_model.last_pruned,
last_permission_sync=(
last_index_attempt.time_started if last_index_attempt else None
),
overall_indexing_speed=overall_indexing_speed,
latest_checkpoint_description=None,
)
@@ -337,9 +308,6 @@ class ConnectorIndexingStatus(ConnectorStatus):
"""Represents the full indexing status of a connector"""
cc_pair_status: ConnectorCredentialPairStatus
# this is separate from the `status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`,
# or `PAUSED` and still be in a repeated error state.
in_repeated_error_state: bool
owner: str
last_finished_status: IndexingStatus | None
last_status: IndexingStatus | None

View File

@@ -118,13 +118,6 @@ class LLMProviderView(LLMProvider):
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
# Safely get groups - handle detached instance case
try:
groups = [group.id for group in llm_provider_model.groups]
except Exception:
# If groups relationship can't be loaded (detached instance), use empty list
groups = []
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -155,7 +148,7 @@ class LLMProviderView(LLMProvider):
else None
),
is_public=llm_provider_model.is_public,
groups=groups,
groups=[group.id for group in llm_provider_model.groups],
deployment_name=llm_provider_model.deployment_name,
)

View File

@@ -11,7 +11,6 @@ from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import QueryExpansions
from shared_configs.model_server_models import Embedding
@@ -75,17 +74,11 @@ class SearchToolOverrideKwargs(BaseModel):
precomputed_keywords: list[str] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
ordering_only: bool | None = (
None # Flag for fast path when search is only needed for ordering
)
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
expanded_queries: QueryExpansions | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -295,7 +295,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
ordering_only = False
document_sources = None
time_cutoff = None
expanded_queries = None
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
@@ -308,10 +307,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
expanded_queries = override_kwargs.expanded_queries
kg_entities = override_kwargs.kg_entities
kg_relationships = override_kwargs.kg_relationships
kg_terms = override_kwargs.kg_terms
# Fast path for ordering-only search
if ordering_only:
@@ -361,16 +356,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Overwritten time-cutoff should supersede the existing time-cutoff, even if defined
retrieval_options.filters.time_cutoff = time_cutoff
# Initialize kg filters in retrieval options and filters with all provided values
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
if kg_entities:
retrieval_options.filters.kg_entities = kg_entities
if kg_relationships:
retrieval_options.filters.kg_relationships = kg_relationships
if kg_terms:
retrieval_options.filters.kg_terms = kg_terms
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
@@ -406,8 +391,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
precomputed_query_embedding=precomputed_query_embedding,
precomputed_is_keyword=precomputed_is_keyword,
precomputed_keywords=precomputed_keywords,
# add expanded queries
expanded_queries=expanded_queries,
),
user=self.user,
llm=self.llm,

View File

@@ -69,7 +69,7 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
"NOTSET": logging.NOTSET,
}
return log_level_dict.get(log_level_str.upper(), logging.INFO)
return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
class OnyxRequestIDFilter(logging.Filter):

View File

@@ -80,7 +80,6 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.46.1
supervisor==4.2.5
thefuzz==0.22.1
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0

View File

@@ -47,7 +47,6 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
@@ -515,7 +514,6 @@ def get_number_of_chunks_we_think_exist(
class VespaDebugging:
# Class for managing Vespa debugging actions.
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
SqlEngine.init_engine(pool_size=20, max_overflow=5)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
self.tenant_id = tenant_id
self.index_name = get_index_name(self.tenant_id)
@@ -857,7 +855,6 @@ def delete_documents_for_tenant(
def main() -> None:
SqlEngine.init_engine(pool_size=20, max_overflow=5)
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(
"--action",

View File

@@ -5,7 +5,6 @@ RUN THIS AFTER SEED_DUMMY_DOCS.PY
import random
import time
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.constants import DocumentSource
from onyx.configs.model_configs import DOC_EMBEDDING_DIM
from onyx.context.search.models import IndexFilters
@@ -97,7 +96,6 @@ def test_hybrid_retrieval_times(
hybrid_alpha=0.5,
time_decay_multiplier=1.0,
num_to_retrieve=50,
ranking_profile_type=QueryExpansionType.SEMANTIC,
offset=0,
title_content_ratio=0.5,
)

View File

@@ -64,7 +64,7 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"
# Enable generating persistent log files for local dev environments
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
# notset, debug, info, notice, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL") or "info"
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
# Timeout for API-based embedding models
# NOTE: does not apply for Google VertexAI, since the python client doesn't

View File

@@ -69,7 +69,6 @@ class CCPairManager:
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
name=name,
@@ -79,7 +78,6 @@ class CCPairManager:
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
refresh_freq=refresh_freq,
)
credential = CredentialManager.create(
credential_json=credential_json,

View File

@@ -78,19 +78,15 @@ class ChatSessionManager:
use_existing_user_message=use_existing_user_message,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
response = requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
headers=headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
stream=True,
cookies=cookies,
)
return ChatSessionManager.analyze_response(response)

View File

@@ -23,7 +23,6 @@ class ConnectorManager:
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestConnector:
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
@@ -37,7 +36,6 @@ class ConnectorManager:
),
access_type=access_type,
groups=groups or [],
refresh_freq=refresh_freq,
)
response = requests.post(

View File

@@ -231,11 +231,6 @@ class DocumentManager:
for doc_dict in retrieved_docs_dict:
doc_id = doc_dict["fields"]["document_id"]
doc_content = doc_dict["fields"]["content"]
image_file_name = doc_dict["fields"].get("image_file_name", None)
final_docs.append(
SimpleTestDocument(
id=doc_id, content=doc_content, image_file_name=image_file_name
)
)
final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
return final_docs

View File

@@ -1,4 +1,3 @@
import io
import mimetypes
from typing import cast
from typing import IO
@@ -63,43 +62,3 @@ class FileManager:
)
response.raise_for_status()
return response.content
@staticmethod
def upload_file_for_connector(
file_path: str, file_name: str, user_performing_action: DATestUser
) -> dict:
# Read the file content
with open(file_path, "rb") as f:
file_content = f.read()
# Create a file-like object
file_obj = io.BytesIO(file_content)
# The 'files' form field expects a list of files
files = [("files", (file_name, file_obj, "application/octet-stream"))]
# Use the user's headers but without Content-Type
# as requests will set the correct multipart/form-data Content-Type for us
headers = user_performing_action.headers.copy()
if "Content-Type" in headers:
del headers["Content-Type"]
# Make the request
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
files=files,
headers=headers,
)
if not response.ok:
try:
error_detail = response.json().get("detail", "Unknown error")
except Exception:
error_detail = response.text
raise Exception(
f"Unable to upload files - {error_detail} (Status code: {response.status_code})"
)
response_json = response.json()
return response_json

View File

@@ -23,7 +23,7 @@ class SettingsManager:
headers.pop("Content-Type", None)
response = requests.get(
f"{API_SERVER_URL}/admin/settings",
f"{API_SERVER_URL}/api/manage/admin/settings",
headers=headers,
)
@@ -48,8 +48,8 @@ class SettingsManager:
headers.pop("Content-Type", None)
payload = settings.model_dump()
response = requests.put(
f"{API_SERVER_URL}/admin/settings",
response = requests.patch(
f"{API_SERVER_URL}/api/manage/admin/settings",
json=payload,
headers=headers,
)

View File

@@ -76,7 +76,6 @@ class DATestConnector(BaseModel):
class SimpleTestDocument(BaseModel):
id: str
content: str
image_file_name: str | None = None
class DATestCCPair(BaseModel):
@@ -178,8 +177,6 @@ class DATestSettings(BaseModel):
gpu_enabled: bool | None = None
product_gating: DATestGatingType = DATestGatingType.NONE
anonymous_user_enabled: bool | None = None
image_extraction_and_analysis_enabled: bool | None = False
search_time_image_analysis_enabled: bool | None = False
@dataclass

View File

@@ -6,8 +6,7 @@ INVITED_BASIC_USER = "basic_user"
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"
def test_admin_can_invite_users(reset_multitenant: None) -> None:
"""Test that an admin can invite both registered and non-registered users."""
def test_user_invitation_flow(reset_multitenant: None) -> None:
# Create first user (admin)
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
@@ -20,44 +19,16 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None:
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
# Verify users are in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"
def test_non_registered_user_gets_basic_role(reset_multitenant: None) -> None:
"""Test that a non-registered user gets a BASIC role when they register after being invited."""
# Create admin user
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Admin user invites a non-registered user
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
# Non-registered user registers
invited_basic_user: DATestUser = UserManager.create(
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
)
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)
def test_user_can_accept_invitation(reset_multitenant: None) -> None:
"""Test that a user can accept an invitation and join the organization with BASIC role."""
# Create admin user
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Create a user to be invited
invited_user_email = "invited_user@test.com"
# User registers with the same email as the invitation
invited_user: DATestUser = UserManager.create(
name="invited_user", email=invited_user_email
)
# Admin user invites the user
UserManager.invite_user(invited_user_email, admin_user)
# Verify the user is in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"
# Get user info to check tenant information
user_info = UserManager.get_user_info(invited_user)
@@ -70,17 +41,16 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None:
)
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"
# User accepts invitation
UserManager.accept_invitation(invited_tenant_id, invited_user)
# User needs to reauthenticate after accepting invitation
# Simulate this by creating a new user instance with the same credentials
authenticated_user: DATestUser = UserManager.create(
name="invited_user", email=invited_user_email
)
# Get updated user info after accepting invitation
updated_user_info = UserManager.get_user_info(invited_user)
# Get updated user info after accepting invitation and reauthenticating
updated_user_info = UserManager.get_user_info(authenticated_user)
# Verify the user is no longer in the invited users list
updated_invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email not in [
user.email for user in updated_invited_users
], f"User {invited_user.email} should not be in invited users list after accepting"
# Verify the user has BASIC role in the organization
assert (
@@ -94,7 +64,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None:
# Check if the invited user is in the list of users with BASIC role
invited_user_emails = [user.email for user in user_page.items]
assert invited_user_email in invited_user_emails, (
f"User {invited_user_email} not found in the list of basic users "
assert invited_user.email in invited_user_emails, (
f"User {invited_user.email} not found in the list of basic users "
f"in the organization. Available users: {invited_user_emails}"
)

View File

@@ -1,5 +1,3 @@
from typing import Any
from onyx.db.models import UserRole
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@@ -13,12 +11,12 @@ from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
"""Helper function to set up test tenants with documents and users."""
# Creating an admin user for Tenant 1
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Creating an admin user (the first user created is automatically an admin and also provisions the tenant)
admin_user1: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
@@ -37,16 +35,6 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
api_key_1.headers.update(admin_user1.headers)
LLMProviderManager.create(user_performing_action=admin_user1)
# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user2,
)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)
# Seed documents for Tenant 1
cc_pair_1.documents = []
doc1_tenant1 = DocumentManager.seed_doc_with_content(
@@ -61,6 +49,16 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
)
cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])
# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user2,
)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)
# Seed documents for Tenant 2
cc_pair_2.documents = []
doc1_tenant2 = DocumentManager.seed_doc_with_content(
@@ -86,36 +84,21 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
user_performing_action=admin_user2
)
return {
"admin_user1": admin_user1,
"admin_user2": admin_user2,
"chat_session1": chat_session1,
"chat_session2": chat_session2,
"tenant1_doc_ids": tenant1_doc_ids,
"tenant2_doc_ids": tenant2_doc_ids,
}
def test_tenant1_can_access_own_documents(reset_multitenant: None) -> None:
"""Test that Tenant 1 can access its own documents but not Tenant 2's."""
test_data = setup_test_tenants(reset_multitenant)
# User 1 sends a message and gets a response
response1 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
chat_session_id=chat_session1.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user1"],
user_performing_action=admin_user1,
)
# Assert that the search tool was used
assert response1.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
assert test_data["tenant1_doc_ids"].issubset(
assert tenant1_doc_ids.issubset(
response_doc_ids
), "Not all Tenant 1 document IDs are in the response"
assert not response_doc_ids.intersection(
test_data["tenant2_doc_ids"]
tenant2_doc_ids
), "Tenant 2 document IDs should not be in the response"
# Assert that the contents are correct
@@ -124,28 +107,21 @@ def test_tenant1_can_access_own_documents(reset_multitenant: None) -> None:
for doc in response1.tool_result or []
), "Tenant 1 Document Content not found in any document"
def test_tenant2_can_access_own_documents(reset_multitenant: None) -> None:
"""Test that Tenant 2 can access its own documents but not Tenant 1's."""
test_data = setup_test_tenants(reset_multitenant)
# User 2 sends a message and gets a response
response2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
chat_session_id=chat_session2.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user2"],
user_performing_action=admin_user2,
)
# Assert that the search tool was used
assert response2.tool_name == "run_search"
# Assert that the tool_result contains Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
assert test_data["tenant2_doc_ids"].issubset(
assert tenant2_doc_ids.issubset(
response_doc_ids
), "Not all Tenant 2 document IDs are in the response"
assert not response_doc_ids.intersection(
test_data["tenant1_doc_ids"]
tenant1_doc_ids
), "Tenant 1 document IDs should not be in the response"
# Assert that the contents are correct
@@ -154,91 +130,28 @@ def test_tenant2_can_access_own_documents(reset_multitenant: None) -> None:
for doc in response2.tool_result or []
), "Tenant 2 Document Content not found in any document"
def test_tenant1_cannot_access_tenant2_documents(reset_multitenant: None) -> None:
"""Test that Tenant 1 cannot access Tenant 2's documents."""
test_data = setup_test_tenants(reset_multitenant)
# User 1 tries to access Tenant 2's documents
response_cross = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
chat_session_id=chat_session1.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user1"],
user_performing_action=admin_user1,
)
# Assert that the search tool was used
assert response_cross.tool_name == "run_search"
# Assert that the tool_result is empty or does not contain Tenant 2's documents
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
# Ensure none of Tenant 2's document IDs are in the response
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
def test_tenant2_cannot_access_tenant1_documents(reset_multitenant: None) -> None:
"""Test that Tenant 2 cannot access Tenant 1's documents."""
test_data = setup_test_tenants(reset_multitenant)
assert not response_doc_ids.intersection(tenant2_doc_ids)
# User 2 tries to access Tenant 1's documents
response_cross2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
chat_session_id=chat_session2.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user2"],
user_performing_action=admin_user2,
)
# Assert that the search tool was used
assert response_cross2.tool_name == "run_search"
# Assert that the tool_result is empty or does not contain Tenant 1's documents
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
# Ensure none of Tenant 1's document IDs are in the response
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
"""Legacy test for multi-tenant access control."""
test_data = setup_test_tenants(reset_multitenant)
# User 1 sends a message and gets a response with only Tenant 1's documents
response1 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user1"],
)
assert response1.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
assert test_data["tenant1_doc_ids"].issubset(response_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
# User 2 sends a message and gets a response with only Tenant 2's documents
response2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user2"],
)
assert response2.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
assert test_data["tenant2_doc_ids"].issubset(response_doc_ids)
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
# User 1 tries to access Tenant 2's documents and fails
response_cross = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session1"].id,
message="What is in Tenant 2's documents?",
user_performing_action=test_data["admin_user1"],
)
assert response_cross.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
# User 2 tries to access Tenant 1's documents and fails
response_cross2 = ChatSessionManager.send_message(
chat_session_id=test_data["chat_session2"].id,
message="What is in Tenant 1's documents?",
user_performing_action=test_data["admin_user2"],
)
assert response_cross2.tool_name == "run_search"
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
assert not response_doc_ids.intersection(tenant1_doc_ids)

View File

@@ -8,51 +8,12 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_first_user_is_admin(reset_multitenant: None) -> None:
"""Test that the first user of a tenant is automatically assigned ADMIN role."""
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)
def test_admin_can_create_credential(reset_multitenant: None) -> None:
"""Test that an admin user can create a credential in their tenant."""
# Create admin user
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)
# Create credential
test_credential = CredentialManager.create(
name="admin_test_credential",
source=DocumentSource.FILE,
curator_public=False,
user_performing_action=test_user,
)
assert test_credential is not None
def test_admin_can_create_connector(reset_multitenant: None) -> None:
"""Test that an admin user can create a connector in their tenant."""
# Create admin user
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)
# Create connector
test_connector = ConnectorManager.create(
name="admin_test_connector",
source=DocumentSource.FILE,
access_type=AccessType.PRIVATE,
user_performing_action=test_user,
)
assert test_connector is not None
def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
"""Test that an admin user can create and verify a connector-credential pair in their tenant."""
# Create admin user
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)
# Create credential
test_credential = CredentialManager.create(
name="admin_test_credential",
source=DocumentSource.FILE,
@@ -60,7 +21,6 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
user_performing_action=test_user,
)
# Create connector
test_connector = ConnectorManager.create(
name="admin_test_connector",
source=DocumentSource.FILE,
@@ -68,7 +28,6 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
user_performing_action=test_user,
)
# Create cc_pair
test_cc_pair = CCPairManager.create(
connector_id=test_connector.id,
credential_id=test_credential.id,
@@ -76,7 +35,5 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
access_type=AccessType.PRIVATE,
user_performing_action=test_user,
)
assert test_cc_pair is not None
# Verify cc_pair
CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=test_user)

View File

@@ -1,117 +0,0 @@
import os
from datetime import datetime
from datetime import timezone
import pytest
from onyx.connectors.models import InputType
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
FILE_NAME = "Sample.pdf"
FILE_PATH = "tests/integration/common_utils/test_files"
def test_image_indexing(
reset: None,
vespa_client: vespa_fixture,
) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
os.makedirs(FILE_PATH, exist_ok=True)
test_file_path = os.path.join(FILE_PATH, FILE_NAME)
# Use FileManager to upload the test file
upload_response = FileManager.upload_file_for_connector(
file_path=test_file_path, file_name=FILE_NAME, user_performing_action=admin_user
)
LLMProviderManager.create(
name="test_llm",
user_performing_action=admin_user,
)
SettingsManager.update_settings(
DATestSettings(
search_time_image_analysis_enabled=True,
image_extraction_and_analysis_enabled=True,
),
user_performing_action=admin_user,
)
file_paths = upload_response.get("file_paths", [])
if not file_paths:
pytest.fail("File upload failed - no file paths returned")
# Create a dummy credential for the file connector
credential = CredentialManager.create(
source=DocumentSource.FILE,
credential_json={},
user_performing_action=admin_user,
)
# Create the connector
connector_name = f"FileConnector-{int(datetime.now().timestamp())}"
connector = ConnectorManager.create(
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={"file_locations": file_paths},
access_type=AccessType.PUBLIC,
groups=[],
user_performing_action=admin_user,
)
# Link the credential to the connector
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.PUBLIC,
user_performing_action=admin_user,
)
# Explicitly run the connector to start indexing
CCPairManager.run_once(
cc_pair=cc_pair,
from_beginning=True,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=datetime.now(timezone.utc),
user_performing_action=admin_user,
)
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
# Ensure we indexed an image from the sample.pdf file
has_sample_pdf_image = False
for doc in documents:
if doc.image_file_name and FILE_NAME in doc.image_file_name:
has_sample_pdf_image = True
# Assert that at least one document has an image file name containing "sample.pdf"
assert (
has_sample_pdf_image
), "No document found with an image file name containing 'sample.pdf'"

View File

@@ -1,18 +0,0 @@
import httpx
import pytest
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
@pytest.fixture
def mock_server_client() -> httpx.Client:
print(
f"Initializing mock server client with host: "
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
f"{MOCK_CONNECTOR_SERVER_PORT}"
)
return httpx.Client(
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
timeout=5.0,
)

View File

@@ -4,6 +4,7 @@ from datetime import timedelta
from datetime import timezone
import httpx
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
@@ -25,6 +26,19 @@ from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.fixture
def mock_server_client() -> httpx.Client:
print(
f"Initializing mock server client with host: "
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
f"{MOCK_CONNECTOR_SERVER_PORT}"
)
return httpx.Client(
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
timeout=5.0,
)
def test_mock_connector_basic_flow(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,

View File

@@ -1,204 +0,0 @@
import time
import uuid
import httpx
from onyx.background.celery.tasks.indexing.utils import (
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE,
)
from onyx.configs.constants import DocumentSource
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import InputType
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
def test_repeated_error_state_detection_and_recovery(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that a connector is marked as in a repeated error state after
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE consecutive failures, and
that it recovers after a successful indexing.
This test ensures we properly wait for the required number of indexing attempts
to fail before checking that the connector is in a repeated error state."""
# Create test document for successful response
test_doc = create_test_document()
# First, set up the mock server to consistently fail
error_response = {
"documents": [],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(mode="json"),
"failures": [],
"unhandled_exception": "Simulated unhandled error for testing repeated errors",
}
# Create a list of failure responses with at least the same length
# as NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
failure_behaviors = [error_response] * (
5 * NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
)
response = mock_server_client.post(
"/set-behavior",
json=failure_behaviors,
)
assert response.status_code == 200
# Create a new CC pair for testing
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-repeated-error-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
refresh_freq=60 * 60, # a very long time
)
# Wait for the required number of failed indexing attempts
# This shouldn't take long, since we keep retrying while we haven't
# succeeded yet
start_time = time.monotonic()
while True:
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair.id,
page=0,
page_size=100,
user_performing_action=admin_user,
)
index_attempts = [
ia
for ia in index_attempts_page.items
if ia.status and ia.status.is_terminal()
]
if len(index_attempts) == NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE:
break
if time.monotonic() - start_time > 180:
raise TimeoutError(
"Did not get required number of failed attempts within 180 seconds"
)
# make sure that we don't mark the connector as in repeated error state
# before we have the required number of failed attempts
with get_session_context_manager() as db_session:
cc_pair_obj = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
assert cc_pair_obj is not None
assert not cc_pair_obj.in_repeated_error_state
time.sleep(2)
# Verify we have the correct number of failed attempts
assert len(index_attempts) == NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
for attempt in index_attempts:
assert attempt.status == IndexingStatus.FAILED
# Check if the connector is in a repeated error state
start_time = time.monotonic()
while True:
with get_session_context_manager() as db_session:
cc_pair_obj = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
assert cc_pair_obj is not None
if cc_pair_obj.in_repeated_error_state:
break
if time.monotonic() - start_time > 30:
assert False, "CC pair did not enter repeated error state within 30 seconds"
time.sleep(2)
# Reset the mock server state
response = mock_server_client.post("/reset")
assert response.status_code == 200
# Now set up the mock server to succeed
success_response = {
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(mode="json"),
"failures": [],
}
response = mock_server_client.post(
"/set-behavior",
json=[success_response],
)
assert response.status_code == 200
# Run another indexing attempt that should succeed
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
index_attempts_to_ignore=[index_attempt.id for index_attempt in index_attempts],
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Validate the indexing succeeded
finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_recovery_attempt.status == IndexingStatus.SUCCESS
# Verify the document was indexed
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
# Verify the CC pair is no longer in a repeated error state
start = time.monotonic()
while True:
with get_session_context_manager() as db_session:
cc_pair_obj = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
assert cc_pair_obj is not None
if not cc_pair_obj.in_repeated_error_state:
break
elapsed = time.monotonic() - start
if elapsed > 30:
raise TimeoutError(
"CC pair did not exit repeated error state within 30 seconds"
)
print(
f"Waiting for CC pair to exit repeated error state. elapsed={elapsed:.2f}"
)
time.sleep(1)

Some files were not shown because too many files have changed in this diff.