Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-19 08:45:47 +00:00)

Compare commits (2 commits: "KG-prototy…", "update")

| Author | SHA1 | Date |
|---|---|---|
|  | bbf5fa13dc |  |
|  | a4a399bc31 |  |
@@ -1,219 +0,0 @@
"""create knowlege graph tables

Revision ID: 495cb26ce93e
Revises: 6a804aeb4830
Create Date: 2025-03-19 08:51:14.341989

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column(
            "classification_requirements",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "extraction_sources", postgresql.JSONB, nullable=False, server_default="{}"
        ),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True, unique=True),
        sa.Column(
            "ge_determine_instructions", postgresql.ARRAY(sa.String()), nullable=True
        ),
        sa.Column("ge_grounding_signature", sa.String(), nullable=True),
    )

    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("clustering", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
    )
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("cluster_count", sa.Integer(), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])
    op.add_column(
        "document",
        sa.Column("kg_processed", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "document",
        sa.Column("kg_data", postgresql.JSONB(), nullable=False, server_default="{}"),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_extraction_enabled",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )

    op.add_column(
        "document_by_connector_credential_pair",
        sa.Column("has_been_kg_processed", sa.Boolean(), nullable=True),
    )


def downgrade() -> None:
    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_entity_type")
    op.drop_column("connector", "kg_extraction_enabled")
    op.drop_column("document_by_connector_credential_pair", "has_been_kg_processed")
    op.drop_column("document", "kg_data")
    op.drop_column("document", "kg_processed")
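Note: once this migration has been applied (for example with `alembic upgrade head`), the new tables can be exercised directly. A minimal sketch, assuming a placeholder connection URL and made-up sample values rather than anything from the diff:

```python
from sqlalchemy import create_engine, text

# Hypothetical DSN; point this at the database the migration was run against.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/onyx")

with engine.begin() as conn:
    # Entity ids elsewhere in the KG code follow an "<ENTITY_TYPE>:<name>" convention.
    conn.execute(
        text(
            "INSERT INTO kg_entity_type (id_name, grounding, active) "
            "VALUES (:id, 'grounded', true)"
        ),
        {"id": "ACCOUNT"},
    )
    rows = conn.execute(text("SELECT id_name, grounding FROM kg_entity_type")).fetchall()
    print(rows)
```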
@@ -1,57 +0,0 @@
"""Update status length

Revision ID: d961aca62eb3
Revises: cf90764725d8
Create Date: 2025-03-23 16:10:05.683965

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d961aca62eb3"
down_revision = "cf90764725d8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing enum type constraint
    op.execute("ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE varchar")

    # Create new enum type with all values
    op.execute(
        "ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE VARCHAR(20) USING status::varchar(20)"
    )

    # Update the enum type to include all possible values
    op.alter_column(
        "connector_credential_pair",
        "status",
        type_=sa.Enum(
            "SCHEDULED",
            "INITIAL_INDEXING",
            "ACTIVE",
            "PAUSED",
            "DELETING",
            "INVALID",
            name="connectorcredentialpairstatus",
            native_enum=False,
        ),
        existing_type=sa.String(20),
        nullable=False,
    )

    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "in_repeated_error_state", sa.Boolean, default=False, server_default="false"
        ),
    )


def downgrade() -> None:
    # no need to convert back to the old enum type, since we're not using it anymore
    op.drop_column("connector_credential_pair", "in_repeated_error_state")
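Note: the key detail above is `native_enum=False`, which makes SQLAlchemy emit a VARCHAR plus a CHECK constraint instead of a PostgreSQL ENUM type, so new status values do not require an `ALTER TYPE`. A small standalone sketch (the table and values are illustrative, not part of the migration):

```python
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateTable

metadata = sa.MetaData()

demo = sa.Table(
    "demo_cc_pair",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column(
        "status",
        # native_enum=False: stored as VARCHAR with a CHECK constraint.
        sa.Enum("ACTIVE", "PAUSED", name="demostatus", native_enum=False),
        nullable=False,
    ),
)

# Print the DDL that would be emitted for PostgreSQL.
print(CreateTable(demo).compile(dialect=postgresql.dialect()))
```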
@@ -140,7 +140,7 @@ def fetch_onyxbot_analytics(
                (
                    or_(
                        ChatMessageFeedback.is_positive.is_(False),
                        ChatMessageFeedback.required_followup.is_(True),
                        ChatMessageFeedback.required_followup,
                    ),
                    1,
                ),
@@ -173,7 +173,7 @@ def fetch_onyxbot_analytics(
        .all()
    )

    return [tuple(row) for row in results]
    return results


def fetch_persona_message_analytics(
@@ -406,6 +406,7 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
            headers=headers,
            json=payload.model_dump(),
        ) as response:
            print(response)
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Control plane tenant creation failed: {error_text}")
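Note: the analytics hunk swaps a bare `return results` for `return [tuple(row) for row in results]`. A tiny illustration of the difference with SQLAlchemy result rows (the query is illustrative, using a throwaway in-memory database):

```python
from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:")  # throwaway database for the demo

with engine.connect() as conn:
    rows = conn.execute(text("SELECT 1 AS feedback_count, 0 AS followup_count")).all()

# Each element is a Row object; callers that expect plain tuples
# (for serialization or equality checks) get them explicitly.
plain = [tuple(row) for row in rows]
print(rows[0], plain[0])
```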
@@ -1,5 +1,3 @@
from typing import cast

import numpy as np
import torch
import torch.nn.functional as F
@@ -41,10 +39,10 @@ logger = setup_logger()

router = APIRouter(prefix="/custom")

_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: AutoTokenizer | None = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None

_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
@@ -52,14 +50,13 @@ _INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> AutoTokenizer:
    global _CONNECTOR_CLASSIFIER_TOKENIZER
    if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
        # The tokenizer details are not uploaded to the HF hub since it's just the
        # unmodified distilbert tokenizer.
        _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        _CONNECTOR_CLASSIFIER_TOKENIZER = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased"
        )
    return _CONNECTOR_CLASSIFIER_TOKENIZER

@@ -95,15 +92,12 @@ def get_local_connector_classifier(
    return _CONNECTOR_CLASSIFIER_MODEL


def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> AutoTokenizer:
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        # The tokenizer details are not uploaded to the HF hub since it's just the
        # unmodified distilbert tokenizer.
        _INTENT_TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    return _INTENT_TOKENIZER

@@ -401,9 +395,9 @@ def run_content_classification_inference(

def map_keywords(
    input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
    input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
    tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    if not len(tokens) == len(is_keyword):
        raise ValueError("Length of tokens and keyword predictions must match")
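Note: these hunks trade bare `AutoTokenizer` annotations for `PreTrainedTokenizer` plus a `cast`, since `AutoTokenizer.from_pretrained` is a factory whose return is not the class itself. A minimal sketch of the pattern in isolation, using the same distilbert checkpoint named in the diff:

```python
from typing import cast

from transformers import AutoTokenizer, PreTrainedTokenizer

# The cached global is annotated with the concrete base class; the cast tells the
# type checker what the factory returns at runtime without changing behavior.
_TOKENIZER: PreTrainedTokenizer | None = None


def get_tokenizer() -> PreTrainedTokenizer:
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = cast(
            PreTrainedTokenizer,
            AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        )
    return _TOKENIZER
```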
@@ -1,6 +1,5 @@
import json
import os
from typing import cast

import torch
import torch.nn as nn
@@ -14,14 +13,15 @@ class HybridClassifier(nn.Module):
        super().__init__()
        config = DistilBertConfig()
        self.distilbert = DistilBertModel(config)
        config = self.distilbert.config  # type: ignore

        # Keyword tokenwise binary classification layer
        self.keyword_classifier = nn.Linear(config.dim, 2)
        self.keyword_classifier = nn.Linear(self.distilbert.config.dim, 2)

        # Intent Classifier layers
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.intent_classifier = nn.Linear(config.dim, 2)
        self.pre_classifier = nn.Linear(
            self.distilbert.config.dim, self.distilbert.config.dim
        )
        self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)

        self.device = torch.device("cpu")
@@ -30,7 +30,7 @@ class HybridClassifier(nn.Module):
        query_ids: torch.Tensor,
        query_mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)  # type: ignore
        outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
        sequence_output = outputs.last_hidden_state

        # Intent classification on the CLS token
@@ -79,9 +79,8 @@ class ConnectorClassifier(nn.Module):
        self.config = config
        self.distilbert = DistilBertModel(config)
        config = self.distilbert.config  # type: ignore
        self.connector_global_classifier = nn.Linear(config.dim, 1)
        self.connector_match_classifier = nn.Linear(config.dim, 1)
        self.connector_global_classifier = nn.Linear(self.distilbert.config.dim, 1)
        self.connector_match_classifier = nn.Linear(self.distilbert.config.dim, 1)
        self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

        # Token indicating end of connector name, and on which classifier is used
@@ -96,7 +95,7 @@ class ConnectorClassifier(nn.Module):
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.distilbert(  # type: ignore
        hidden_states = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

@@ -115,10 +114,7 @@ class ConnectorClassifier(nn.Module):
    @classmethod
    def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
        config = cast(
            DistilBertConfig,
            DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
        )
        config = DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json"))
        device = (
            torch.device("cuda")
            if torch.cuda.is_available()
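Note: the classifier hunks mainly swap repeated `self.distilbert.config.dim` lookups for a local `config` variable read off the backbone. A compact sketch of the same idea, with DistilBERT defaults (dim is 768 unless overridden):

```python
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

backbone = DistilBertModel(DistilBertConfig())

# Reusing the backbone's own config keeps the head sizes in sync with the encoder.
cfg = backbone.config
keyword_head = nn.Linear(cfg.dim, 2)
intent_head = nn.Linear(cfg.dim, 2)
print(cfg.dim, keyword_head)
```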
@@ -1,18 +0,0 @@
from pydantic import BaseModel


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None
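Note: these are plain Pydantic models, which is what lets the extraction nodes further down validate raw LLM output directly. A small standalone sketch of that round trip (the JSON string is made up):

```python
from pydantic import BaseModel


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


raw = '{"entities": ["ACCOUNT:acme"], "terms": ["renewal"], "time_filter": null}'
result = KGQuestionEntityExtractionResult.model_validate_json(raw)
print(result.entities, result.time_filter)
```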
@@ -37,7 +37,7 @@ def research_object_source(
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.search_request.query
    search_tool = graph_config.tooling.search_tool
    graph_config.inputs.search_request.query
    question = graph_config.inputs.search_request.query
    object, document_source = state.object_source_combination

    if search_tool is None or graph_config.inputs.search_request.persona is None:
@@ -17,9 +17,6 @@ def research(
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
    kg_entities: list[str] | None = None,
    kg_relationships: list[str] | None = None,
    kg_terms: list[str] | None = None,
) -> list[LlmDoc]:
    # new db session to avoid concurrency issues
@@ -36,9 +33,6 @@ def research(
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
            kg_entities=kg_entities,
            kg_relationships=kg_relationships,
            kg_terms=kg_terms,
        ),
    ):
        # get retrieved docs to send to the rest of the graph
@@ -1,54 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from typing import Literal

from langgraph.types import Send

from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput


def simple_vs_search(
    state: MainState,
) -> Literal["process_kg_only_answers", "construct_deep_search_filters"]:
    if state.strategy == KGAnswerStrategy.DEEP or len(state.relationships) > 0:
        return "construct_deep_search_filters"
    else:
        return "process_kg_only_answers"


def research_individual_object(
    state: MainState,
    # ) -> list[Send | Hashable] | Literal["individual_deep_search"]:
) -> list[Send | Hashable]:
    edge_start_time = datetime.now()

    # if (
    #     not state.div_con_entities
    #     or not state.broken_down_question
    #     or not state.vespa_filter_results
    # ):
    #     return "individual_deep_search"

    # else:

    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None

    return [
        Send(
            "process_individual_deep_search",
            ResearchObjectInput(
                entity=entity,
                broken_down_question=state.broken_down_question,
                vespa_filter_results=state.vespa_filter_results,
                source_division=state.source_division,
                log_messages=[
                    f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                ],
            ),
        )
        for entity in state.div_con_entities
    ]
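Note: for context on the `Send` fan-out used in `research_individual_object`, here is a stripped-down sketch with toy state types rather than the Onyx ones, following LangGraph's documented map-reduce pattern:

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class OverallState(TypedDict):
    items: list[str]
    # The reducer lets parallel branches append instead of clobbering each other.
    results: Annotated[list[str], operator.add]


class WorkerState(TypedDict):
    item: str


def fan_out(state: OverallState) -> list[Send]:
    # One Send per item schedules a separate invocation of the worker node,
    # mirroring how research_individual_object parallelizes per entity.
    return [Send("worker", {"item": item}) for item in state["items"]]


def worker(state: WorkerState) -> dict:
    return {"results": [state["item"].upper()]}


builder = StateGraph(OverallState)
builder.add_node("worker", worker)
builder.add_conditional_edges(START, fan_out, ["worker"])
builder.add_edge("worker", END)
app = builder.compile()
print(app.invoke({"items": ["a", "b"], "results": []}))
```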
@@ -1,138 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.kb_search.conditional_edges import (
    research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
    generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
    construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
    process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b3_consoldidate_individual_deep_search import (
    consoldidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
    process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger

logger = setup_logger()

test_mode = False


def kb_graph_builder(test_mode: bool = False) -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """

    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###

    graph.add_node(
        "extract_ert",
        extract_ert,
    )

    graph.add_node(
        "generate_simple_sql",
        generate_simple_sql,
    )

    graph.add_node(
        "analyze",
        analyze,
    )

    graph.add_node(
        "generate_answer",
        generate_answer,
    )

    graph.add_node(
        "construct_deep_search_filters",
        construct_deep_search_filters,
    )

    graph.add_node(
        "process_individual_deep_search",
        process_individual_deep_search,
    )

    # graph.add_node(
    #     "individual_deep_search",
    #     individual_deep_search,
    # )

    graph.add_node(
        "consoldidate_individual_deep_search",
        consoldidate_individual_deep_search,
    )

    graph.add_node("process_kg_only_answers", process_kg_only_answers)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="extract_ert")

    graph.add_edge(
        start_key="extract_ert",
        end_key="analyze",
    )

    graph.add_edge(
        start_key="analyze",
        end_key="generate_simple_sql",
    )

    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)

    graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")

    # graph.add_edge(
    #     start_key="construct_deep_search_filters",
    #     end_key="process_individual_deep_search",
    # )

    graph.add_conditional_edges(
        source="construct_deep_search_filters",
        path=research_individual_object,
        path_map=["process_individual_deep_search"],
    )

    graph.add_edge(
        start_key="process_individual_deep_search",
        end_key="consoldidate_individual_deep_search",
    )

    # graph.add_edge(
    #     start_key="individual_deep_search",
    #     end_key="consoldidate_individual_deep_search",
    # )

    graph.add_edge(
        start_key="consoldidate_individual_deep_search", end_key="generate_answer"
    )

    graph.add_edge(
        start_key="generate_answer",
        end_key=END,
    )

    return graph
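Note: a `StateGraph` returned by a builder like this still has to be compiled before it can run. A hedged sketch only; the real `MainInput` and `GraphConfig` objects come from the rest of the Onyx agent tooling and are not constructed here:

```python
# Sketch: kb_graph_builder is the function from the removed module above.
graph = kb_graph_builder()
compiled = graph.compile()

# Compiled LangGraph graphs expose invoke/stream; the nodes above read their
# configuration via config["metadata"]["config"], so a call would look roughly like:
# result = compiled.invoke(main_input, config={"metadata": {"config": graph_config}})
```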
@@ -1,217 +0,0 @@
from typing import Set

from onyx.agents.agent_search.kb_search.models import KGExpendedGraphObjects
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_relationships_of_entity
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _check_entities_disconnected(
    current_entities: list[str], current_relationships: list[str]
) -> bool:
    """
    Check if all entities in current_entities are disconnected via the given relationships.
    Relationships are in the format: source_entity__relationship_name__target_entity

    Args:
        current_entities: List of entity IDs to check connectivity for
        current_relationships: List of relationships in format source__relationship__target

    Returns:
        bool: True if all entities are disconnected, False otherwise
    """
    if not current_entities:
        return True

    # Create a graph representation using adjacency list
    graph: dict[str, set[str]] = {entity: set() for entity in current_entities}

    # Build the graph from relationships
    for relationship in current_relationships:
        try:
            source, _, target = relationship.split("__")
            if source in graph and target in graph:
                graph[source].add(target)
                graph[target].add(source)  # Add reverse edge since graph is undirected
        except ValueError:
            raise ValueError(f"Invalid relationship format: {relationship}")

    # Use BFS to check if all entities are connected
    visited: set[str] = set()
    start_entity = current_entities[0]

    def _bfs(start: str) -> None:
        queue = [start]
        visited.add(start)
        while queue:
            current = queue.pop(0)
            for neighbor in graph[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    # Start BFS from the first entity
    _bfs(start_entity)

    logger.debug(f"Number of visited entities: {len(visited)}")

    # Check if all current_entities are in visited
    return not all(entity in visited for entity in current_entities)


def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 2
) -> KGExpendedGraphObjects:
    """
    Find the minimal subgraph that connects all input entities, using only general entities
    (<entity_type>:*) as intermediate nodes. The subgraph will include only the relationships
    necessary to connect all input entities through the shortest possible paths.

    Args:
        entities: Initial list of entity IDs
        relationships: Initial list of relationships in format source__relationship__target
        max_depth: Maximum depth to expand the graph (default: 2)

    Returns:
        KGExpendedGraphObjects containing expanded entities and relationships
    """
    # Create copies of input lists to avoid modifying originals
    expanded_entities = entities.copy()
    expanded_relationships = relationships.copy()

    # Keep track of original entities
    original_entities = set(entities)

    # Build initial graph from existing relationships
    graph: dict[str, set[tuple[str, str]]] = {
        entity: set() for entity in expanded_entities
    }
    for rel in relationships:
        try:
            source, rel_name, target = rel.split("__")
            if source in graph and target in graph:
                graph[source].add((target, rel_name))
                graph[target].add((source, rel_name))
        except ValueError:
            continue

    # For each depth level
    counter = 0
    while counter < max_depth:
        # Find all connected components in the current graph
        components = []
        visited = set()

        def dfs(node: str, component: set[str]) -> None:
            visited.add(node)
            component.add(node)
            for neighbor, _ in graph.get(node, set()):
                if neighbor not in visited:
                    dfs(neighbor, component)

        # Find all components
        for entity in expanded_entities:
            if entity not in visited:
                component: Set[str] = set()
                dfs(entity, component)
                components.append(component)

        # If we only have one component, we're done
        if len(components) == 1:
            break

        # Find the shortest path between any two components using general entities
        shortest_path = None
        shortest_path_length = float("inf")

        for comp1 in components:
            for comp2 in components:
                if comp1 == comp2:
                    continue

                # Try to find path between entities in different components
                for entity1 in comp1:
                    if not any(e in original_entities for e in comp1):
                        continue

                    # entity1_type = entity1.split(":")[0]

                    with get_session_with_current_tenant() as db_session:
                        entity1_rels = get_relationships_of_entity(db_session, entity1)

                    for rel1 in entity1_rels:
                        try:
                            source1, rel_name1, target1 = rel1.split("__")
                            if source1 != entity1:
                                continue

                            target1_type = target1.split(":")[0]
                            general_target = f"{target1_type}:*"

                            # Try to find path from general_target to comp2
                            for entity2 in comp2:
                                if not any(e in original_entities for e in comp2):
                                    continue

                                with get_session_with_current_tenant() as db_session:
                                    entity2_rels = get_relationships_of_entity(
                                        db_session, entity2
                                    )

                                for rel2 in entity2_rels:
                                    try:
                                        source2, rel_name2, target2 = rel2.split("__")
                                        if target2 != entity2:
                                            continue

                                        source2_type = source2.split(":")[0]
                                        general_source = f"{source2_type}:*"

                                        if general_target == general_source:
                                            # Found a path of length 2
                                            path = [
                                                (entity1, rel_name1, general_target),
                                                (general_target, rel_name2, entity2),
                                            ]
                                            if len(path) < shortest_path_length:
                                                shortest_path = path
                                                shortest_path_length = len(path)

                                    except ValueError:
                                        continue

                        except ValueError:
                            continue

        # If we found a path, add it to our graph
        if shortest_path:
            for source, rel_name, target in shortest_path:
                # Add general entity if needed
                if ":*" in source and source not in expanded_entities:
                    expanded_entities.append(source)
                if ":*" in target and target not in expanded_entities:
                    expanded_entities.append(target)

                # Add relationship
                rel = f"{source}__{rel_name}__{target}"
                if rel not in expanded_relationships:
                    expanded_relationships.append(rel)

                # Update graph
                if source not in graph:
                    graph[source] = set()
                if target not in graph:
                    graph[target] = set()
                graph[source].add((target, rel_name))
                graph[target].add((source, rel_name))

        counter += 1

    logger.debug(f"Number of expanded entities: {len(expanded_entities)}")
    logger.debug(f"Number of expanded relationships: {len(expanded_relationships)}")

    return KGExpendedGraphObjects(
        entities=expanded_entities, relationships=expanded_relationships
    )
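Note: a quick illustration of the connectivity check above, using toy identifiers in the `source__relationship__target` format described in its docstring (calls the helper defined in the removed module):

```python
entities = ["ACCOUNT:acme", "EMPLOYEE:jane"]
relationships = ["ACCOUNT:acme__employs__EMPLOYEE:jane"]

# Every entity is reachable from the first one, so they are not disconnected.
print(_check_entities_disconnected(entities, relationships))  # False

# Drop the only edge and the two entities fall into separate components.
print(_check_entities_disconnected(entities, []))  # True
```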
@@ -1,34 +0,0 @@
from pydantic import BaseModel

from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import YesNoEnum


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


class KGAnswerApproach(BaseModel):
    strategy: KGAnswerStrategy
    format: KGAnswerFormat
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None


class KGExpendedGraphObjects(BaseModel):
    entities: list[str]
    relationships: list[str]
@@ -1,240 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
    KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import (
    ERTExtractionUpdate,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query

    # first four lines duplicates from generate_initial_answer
    question = graph_config.inputs.search_request.query
    today_date = datetime.now().strftime("%A, %Y-%m-%d")

    all_entity_types = get_entity_types_str(active=None)
    all_relationship_types = get_relationship_types_str(active=None)

    ### get the entities, terms, and filters

    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )

    query_extraction_prompt = query_extraction_pre_prompt.replace(
        "---content---", question
    ).replace("---today_date---", today_date)

    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            entity_extraction_result = (
                KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            entity_extraction_result = KGQuestionEntityExtractionResult(
                entities=[],
                terms=[],
                time_filter="",
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[],
            terms=[],
            time_filter="",
        )

    ert_entities_string = f"Entities: {entity_extraction_result.entities}\n"
    # ert_terms_string = f"Terms: {entity_extraction_result.terms}"
    # ert_time_filter_string = f"Time Filter: {entity_extraction_result.time_filter}\n"

    ### get the relationships

    # find the relationship types that match the extracted entity types

    with get_session_with_current_tenant() as db_session:
        allowed_relationship_pairs = get_allowed_relationship_type_pairs(
            db_session, entity_extraction_result.entities
        )

    query_relationship_extraction_prompt = (
        QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
        .replace("---today_date---", today_date)
        .replace(
            "---relationship_type_options---",
            " - " + "\n - ".join(allowed_relationship_pairs),
        )
        .replace("---identified_entities---", ert_entities_string)
        .replace("---entity_types---", all_entity_types)
    )

    msg = [
        HumanMessage(
            content=query_relationship_extraction_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            relationship_extraction_result = (
                KGQuestionRelationshipExtractionResult.model_validate_json(
                    cleaned_response
                )
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            relationship_extraction_result = KGQuestionRelationshipExtractionResult(
                relationships=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        relationship_extraction_result = KGQuestionRelationshipExtractionResult(
            relationships=[],
        )

    # ert_relationships_string = (
    #     f"Relationships: {relationship_extraction_result.relationships}\n"
    # )

    ##

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_entities_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_relationships_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_terms_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=ert_time_filter_string,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # dispatch_main_answer_stop_info(0, writer)

    return ERTExtractionUpdate(
        entities_types_str=all_entity_types,
        entities=entity_extraction_result.entities,
        relationships=relationship_extraction_result.relationships,
        terms=entity_extraction_result.terms,
        time_filter=entity_extraction_result.time_filter,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="extract entities terms",
                node_start_time=node_start_time,
            )
        ],
    )
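Note: the same response-cleaning idiom (strip markdown fences, keep only the outermost JSON object) recurs in several of these nodes. Isolated into a small helper, it looks roughly like this:

```python
def extract_json_object(raw_llm_output: str) -> str:
    # Strip code fences and newlines, then keep only the outermost {...} span,
    # mirroring the cleanup done before model_validate_json in the nodes above.
    cleaned = (
        raw_llm_output.replace("```json\n", "").replace("\n```", "").replace("\n", "")
    )
    first, last = cleaned.find("{"), cleaned.rfind("}")
    return cleaned[first : last + 1]


print(extract_json_object('```json\n{"entities": [], "terms": [], "time_filter": null}\n```'))
```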
@@ -1,210 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import (
    create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def analyze(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities = state.entities
    relationships = state.relationships
    terms = state.terms
    time_filter = state.time_filter

    normalized_entities = normalize_entities(entities)
    normalized_relationships = normalize_relationships(
        relationships, normalized_entities.entity_normalization_map
    )
    normalized_terms = normalize_terms(terms)
    normalized_time_filter = time_filter

    # Expand the entities and relationships to make sure that entities are connected

    graph_expansion = create_minimal_connected_query_graph(
        normalized_entities.entities,
        normalized_relationships.relationships,
        max_depth=2,
    )

    query_graph_entities = graph_expansion.entities
    query_graph_relationships = graph_expansion.relationships

    # Evaluate whether a search needs to be done after identifying all entities and relationships

    strategy_generation_prompt = (
        STRATEGY_GENERATION_PROMPT.replace(
            "---entities---", "\n".join(query_graph_entities)
        )
        .replace("---relationships---", "\n".join(query_graph_relationships))
        .replace("---terms---", "\n".join(normalized_terms.terms))
        .replace("---question---", question)
    )

    msg = [
        HumanMessage(
            content=strategy_generation_prompt,
        )
    ]
    # fast_llm = graph_config.tooling.fast_llm
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            20,
            # fast_llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=5,
            max_tokens=100,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            approach_extraction_result = KGAnswerApproach.model_validate_json(
                cleaned_response
            )
            strategy = approach_extraction_result.strategy
            output_format = approach_extraction_result.format
            broken_down_question = approach_extraction_result.broken_down_question
            divide_and_conquer = approach_extraction_result.divide_and_conquer
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            strategy = KGAnswerStrategy.DEEP
            output_format = KGAnswerFormat.TEXT
            broken_down_question = None
            divide_and_conquer = YesNoEnum.NO
        if strategy is None or output_format is None:
            raise ValueError(f"Invalid strategy: {cleaned_response}")

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(normalized_entities.entities),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(normalized_relationships.relationships),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(query_graph_entities),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )
    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece="\n".join(query_graph_relationships),
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=strategy.value,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=output_format.value,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # dispatch_main_answer_stop_info(0, writer)

    return AnalysisUpdate(
        normalized_core_entities=normalized_entities.entities,
        normalized_core_relationships=normalized_relationships.relationships,
        query_graph_entities=query_graph_entities,
        query_graph_relationships=query_graph_relationships,
        normalized_terms=normalized_terms.terms,
        normalized_time_filter=normalized_time_filter,
        strategy=strategy,
        broken_down_question=broken_down_question,
        output_format=output_format,
        divide_and_conquer=divide_and_conquer,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="analyze",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,224 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text

from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def _sql_is_aggregate_query(sql_statement: str) -> bool:
    return any(
        agg_func in sql_statement.upper()
        for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
    )


def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
    """
    Remove aggregate functions from the SQL statement.
    """

    sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
        "---sql_statement---", sql_statement
    )

    msg = [
        HumanMessage(
            content=sql_aggregation_removal_prompt,
        )
    ]

    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=800,
        )

        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        sql_statement = cleaned_response.split("SQL:")[1].strip()
        sql_statement = sql_statement.split(";")[0].strip() + ";"
        sql_statement = sql_statement.replace("sql", "").strip()

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    return sql_statement


def generate_simple_sql(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities_types_str = state.entities_types_str
    state.strategy
    state.output_format

    simple_sql_prompt = (
        SIMPLE_SQL_PROMPT.replace("---entities_types---", entities_types_str)
        .replace("---question---", question)
        .replace("---query_entities---", "\n".join(state.query_graph_entities))
        .replace(
            "---query_relationships---", "\n".join(state.query_graph_relationships)
        )
    )

    msg = [
        HumanMessage(
            content=simple_sql_prompt,
        )
    ]
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=800,
        )

        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        sql_statement = cleaned_response.split("SQL:")[1].strip()
        sql_statement = sql_statement.split(";")[0].strip() + ";"
        sql_statement = sql_statement.replace("sql", "").strip()

        # reasoning = cleaned_response.split("SQL:")[0].strip()

    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e

    if _sql_is_aggregate_query(sql_statement):
        individualized_sql_query = _remove_aggregation(sql_statement, llm=fast_llm)
    else:
        individualized_sql_query = None

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=reasoning,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=cleaned_response,
    #         level=0,
    #         level_question_num=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # CRITICAL: EXECUTION OF SQL NEEDS TO ME MADE SAFE FOR PRODUCTION
    with get_session_with_current_tenant() as db_session:
        try:
            result = db_session.execute(text(sql_statement))
            # Handle scalar results (like COUNT)
            if sql_statement.upper().startswith("SELECT COUNT"):
                scalar_result = result.scalar()
                query_results = (
                    [{"count": int(scalar_result) - 1}]
                    if scalar_result is not None
                    else []
                )
            else:
                # Handle regular row results
                rows = result.fetchall()
                query_results = [dict(row._mapping) for row in rows]
        except Exception as e:
            logger.error(f"Error executing SQL query: {e}")

            raise e

    if (
        individualized_sql_query is not None
        and individualized_sql_query != sql_statement
    ):
        with get_session_with_current_tenant() as db_session:
            try:
                result = db_session.execute(text(individualized_sql_query))
                # Handle scalar results (like COUNT)
                if individualized_sql_query.upper().startswith("SELECT COUNT"):
                    scalar_result = result.scalar()
                    individualized_query_results = (
                        [{"count": int(scalar_result) - 1}]
                        if scalar_result is not None
                        else []
                    )
                else:
                    # Handle regular row results
                    rows = result.fetchall()
                    individualized_query_results = [dict(row._mapping) for row in rows]
            except Exception as e:
                # No stopping here, the individualized SQL query is not mandatory
                logger.error(f"Error executing Individualized SQL query: {e}")
                individualized_query_results = None

    else:
        individualized_query_results = None

    # write_custom_event(
    #     "initial_agent_answer",
    #     AgentAnswerPiece(
    #         answer_piece=str(query_results),
    #         level=0,
    #         answer_type="agent_level_answer",
    #     ),
    #     writer,
    # )

    # dispatch_main_answer_stop_info(0, writer)

    logger.info(f"query_results: {query_results}")

    return SQLSimpleGenerationUpdate(
        sql_query=sql_statement,
        query_results=query_results,
        individualized_sql_query=individualized_sql_query,
        individualized_query_results=individualized_query_results,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="generate simple sql",
                node_start_time=node_start_time,
            )
        ],
    )
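Note: the aggregate detection above is a plain substring scan on the upper-cased statement, so it can also flag statements that merely contain an aggregate-looking token inside a literal. A quick example using the helper defined in the removed module:

```python
print(_sql_is_aggregate_query("SELECT COUNT(*) FROM kg_entity;"))         # True
print(_sql_is_aggregate_query("SELECT id_name FROM kg_entity LIMIT 5;"))  # False
```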
@@ -1,158 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entity_types_with_grounded_source_name
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def construct_deep_search_filters(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query

    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query

    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )

    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]
    llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            vespa_filter_results = KGVespaFilterResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            vespa_filter_results = KGVespaFilterResults(
                entity_filters=[],
                relationship_filters=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        vespa_filter_results = KGVespaFilterResults(
            entity_filters=[],
            relationship_filters=[],
        )

    if (
        state.individualized_query_results
        and len(state.individualized_query_results) > 0
    ):
        div_con_entities = [
            x["id_name"]
            for x in state.individualized_query_results
            if x["id_name"] is not None and "*" not in x["id_name"]
        ]
    elif state.query_results and len(state.query_results) > 0:
        div_con_entities = [
            x["id_name"]
            for x in state.query_results
            if x["id_name"] is not None and "*" not in x["id_name"]
        ]
    else:
        div_con_entities = []

    div_con_entities = list(set(div_con_entities))

    logger.info(f"div_con_entities: {div_con_entities}")

    with get_session_with_current_tenant() as db_session:
        double_grounded_entity_types = get_entity_types_with_grounded_source_name(
            db_session
        )

    source_division = False

    if div_con_entities:
        for entity_type in double_grounded_entity_types:
            if entity_type.grounded_source_name.lower() in div_con_entities[0].lower():
                source_division = True
                break
    else:
        raise ValueError("No div_con_entities found")

    return DeepSearchFilterUpdate(
        vespa_filter_results=vespa_filter_results,
        div_con_entities=div_con_entities,
        source_division=source_division,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,133 +0,0 @@
import copy
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def process_individual_deep_search(
    state: ResearchObjectInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = state.broken_down_question
    source_division = state.source_division

    if not search_tool:
        raise ValueError("search_tool is not provided")

    object = state.entity.replace(":", ": ").lower()

    if source_division:
        extended_question = question
    else:
        extended_question = f"{question} in regards to {object}"

    kg_entity_filters = copy.deepcopy(
        state.vespa_filter_results.entity_filters + [state.entity]
    )
    kg_relationship_filters = copy.deepcopy(
        state.vespa_filter_results.relationship_filters
    )

    logger.info("Research for object: " + object)
    logger.info(f"kg_entity_filters: {kg_entity_filters}")
    logger.info(f"kg_relationship_filters: {kg_relationship_filters}")

    # Add random wait between 1-3 seconds
    # time.sleep(random.uniform(0, 3))

    retrieved_docs = research(
        question=extended_question,
        kg_entities=kg_entity_filters,
        kg_relationships=kg_relationship_filters,
        search_tool=search_tool,
    )

    document_texts_list = []
    for doc_num, doc in enumerate(retrieved_docs):
        chunk_text = "Document " + str(doc_num + 1) + ":\n" + doc.content
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    # Built prompt

    datetime.now().strftime("%A, %Y-%m-%d")

    kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
        question=extended_question,
        document_text=document_texts,
    )

    # Run LLM

    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    # fast_llm = graph_config.tooling.fast_llm
    primary_llm = graph_config.tooling.primary_llm
    llm = primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )

        object_research_results = str(llm_response.content).replace("```json\n", "")

    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ResearchObjectUpdate(
        research_object_results=[
            {
                "object": object.replace(":", ": ").capitalize(),
                "results": object_research_results,
            }
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="process individual deep search",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,125 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGVespaFilterResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def individual_deep_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> DeepSearchFilterUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query

    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query

    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )

    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]
    llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            vespa_filter_results = KGVespaFilterResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            vespa_filter_results = KGVespaFilterResults(
                entity_filters=[],
                relationship_filters=[],
            )
    except Exception as e:
        logger.error(f"Error in extract_ert: {e}")
        vespa_filter_results = KGVespaFilterResults(
            entity_filters=[],
            relationship_filters=[],
        )

    if state.query_results:
        div_con_entities = [
            x["id_name"] for x in state.query_results if x["id_name"] is not None
        ]
    else:
        div_con_entities = []

    return DeepSearchFilterUpdate(
        vespa_filter_results=vespa_filter_results,
        div_con_entities=div_con_entities,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,45 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def consoldidate_individual_deep_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.search_request.query
    state.entities_types_str

    research_object_results = state.research_object_results

    consolidated_research_object_results_str = "\n".join(
        [f"{x['object']}: {x['results']}" for x in research_object_results]
    )

    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=consolidated_research_object_results_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="generate simple sql",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,107 +0,0 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger


logger = setup_logger()


def _general_format(result: dict[str, Any]) -> str:
    name = result.get("name")
    entity_type_id_name = result.get("entity_type_id_name")
    result.get("id_name")

    if entity_type_id_name:
        return f"{entity_type_id_name.capitalize()}: {name}"
    else:
        return f"{name}"


def _generate_reference_results(
    individualized_query_results: list[dict[str, Any]]
) -> ReferenceResults:
    """
    Generate reference results from the query results data string.
    """

    citations: list[str] = []
    general_entities = []

    # get all entities that correspond to an Onyx document
    document_ids: list[str] = [
        cast(str, x.get("document_id"))
        for x in individualized_query_results
        if x.get("document_id")
    ]

    with get_session_with_current_tenant() as session:
        llm_doc_information_results = get_base_llm_doc_information(
            session, document_ids
        )

        for llm_doc_information_result in llm_doc_information_results:
            citations.append(llm_doc_information_result.center_chunk.semantic_identifier)

    for result in individualized_query_results:
        document_id: str | None = result.get("document_id")

        if not document_id:
            general_entities.append(_general_format(result))

    return ReferenceResults(citations=citations, general_entities=general_entities)


def process_kg_only_answers(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    graph_config.inputs.search_request.query
    query_results = state.query_results
    individualized_query_results = state.individualized_query_results

    query_results_list = []

    if query_results:
        for query_result in query_results:
            query_results_list.append(str(query_result).replace(":", ": ").capitalize())
    else:
        raise ValueError("No query results were found")

    query_results_data_str = "\n".join(query_results_list)

    if individualized_query_results:
        reference_results = _generate_reference_results(individualized_query_results)
    else:
        reference_results = None

    return ResultsDataUpdate(
        query_results_data_str=query_results_data_str,
        individualized_query_results_data_str="",
        reference_results=reference_results,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="kg query results data processing",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,157 +0,0 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def generate_answer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
    """
    LangGraph node to start the agentic search process.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    state.entities_types_str
    introductory_answer = state.query_results_data_str
    state.reference_results
    search_tool = graph_config.tooling.search_tool
    if search_tool is None:
        raise ValueError("Search tool is not set")

    consolidated_research_object_results_str = (
        state.consolidated_research_object_results_str
    )

    question = graph_config.inputs.search_request.query

    output_format = state.output_format
    if state.reference_results:
        examples = (
            state.reference_results.citations
            or state.reference_results.general_entities
            or []
        )
        research_results = "\n".join([f"- {example}" for example in examples])
    elif consolidated_research_object_results_str:
        research_results = consolidated_research_object_results_str
    else:
        research_results = ""

    if research_results and introductory_answer:
        output_format_prompt = (
            OUTPUT_FORMAT_PROMPT.replace("---question---", question)
            .replace("---introductory_answer---", introductory_answer)
            .replace("---output_format---", str(output_format) if output_format else "")
            .replace("---research_results---", research_results)
        )

    elif not research_results and introductory_answer:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace("---introductory_answer---", introductory_answer)
            .replace("---output_format---", str(output_format) if output_format else "")
        )
    elif research_results and not introductory_answer:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
            .replace("---output_format---", str(output_format) if output_format else "")
            .replace("---research_results---", research_results)
        )
    else:
        raise ValueError("No research results or introductory answer provided")

    msg = [
        HumanMessage(
            content=output_format_prompt,
        )
    ]
    fast_llm = graph_config.tooling.fast_llm
    dispatch_timings: list[float] = []
    response: list[str] = []

    def stream_answer() -> list[str]:
        for message in fast_llm.stream(
            prompt=msg,
            timeout_override=30,
            max_tokens=1000,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()
            write_custom_event(
                "initial_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
            # logger.debug(f"Answer piece: {content}")
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            response.append(content)
        return response

    try:
        response = run_with_timeout(
            30,
            stream_answer,
        )

    except Exception as e:
        raise ValueError(f"Could not generate the answer. Error {e}")

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_ANSWER,
        level=0,
        level_question_num=0,
    )
    write_custom_event("stream_finished", stop_event, writer)

    dispatch_main_answer_stop_info(0, writer)

    return MainOutput(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="query completed",
                node_start_time=node_start_time,
            )
        ],
    )
@@ -1,42 +0,0 @@
from datetime import datetime
from typing import cast

from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool


def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
    # new db session to avoid concurrency issues

    callback_container: list[list[InferenceSection]] = []
    retrieved_docs: list[LlmDoc] = []

    with get_session_with_current_tenant() as db_session:
        for tool_response in search_tool.run(
            query=question,
            override_kwargs=SearchToolOverrideKwargs(
                force_no_rerank=False,
                alternate_db_session=db_session,
                retrieved_sections_callback=callback_container.append,
                skip_query_analysis=True,
                document_sources=document_sources,
                time_cutoff=time_cutoff,
            ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
                break
    return retrieved_docs
@@ -1,127 +0,0 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict

from pydantic import BaseModel

from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import ReferenceResults


### States ###
class LoggerUpdate(BaseModel):
    log_messages: Annotated[list[str], add] = []


class KGVespaFilterResults(BaseModel):
    entity_filters: list[str]
    relationship_filters: list[str]


class KGAnswerStrategy(Enum):
    DEEP = "DEEP"
    SIMPLE = "SIMPLE"


class KGAnswerFormat(Enum):
    LIST = "LIST"
    TEXT = "TEXT"


class YesNoEnum(str, Enum):
    YES = "yes"
    NO = "no"


class AnalysisUpdate(LoggerUpdate):
    normalized_core_entities: list[str] = []
    normalized_core_relationships: list[str] = []
    query_graph_entities: list[str] = []
    query_graph_relationships: list[str] = []
    normalized_terms: list[str] = []
    normalized_time_filter: str | None = None
    strategy: KGAnswerStrategy | None = None
    output_format: KGAnswerFormat | None = None
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None


class SQLSimpleGenerationUpdate(LoggerUpdate):
    sql_query: str | None = None
    query_results: list[Dict[Any, Any]] | None = None
    individualized_sql_query: str | None = None
    individualized_query_results: list[Dict[Any, Any]] | None = None


class ConsolidatedResearchUpdate(LoggerUpdate):
    consolidated_research_object_results_str: str | None = None


class DeepSearchFilterUpdate(LoggerUpdate):
    vespa_filter_results: KGVespaFilterResults | None = None
    div_con_entities: list[str] | None = None
    source_division: bool | None = None


class ResearchObjectOutput(LoggerUpdate):
    research_object_results: Annotated[list[dict[str, Any]], add] = []


class ERTExtractionUpdate(LoggerUpdate):
    entities_types_str: str = ""
    entities: list[str] = []
    relationships: list[str] = []
    terms: list[str] = []
    time_filter: str | None = None


class ResultsDataUpdate(LoggerUpdate):
    query_results_data_str: str | None = None
    individualized_query_results_data_str: str | None = None
    reference_results: ReferenceResults | None = None


class ResearchObjectUpdate(LoggerUpdate):
    research_object_results: Annotated[list[dict[str, Any]], add] = []


## Graph Input State
class MainInput(CoreState):
    pass


## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    ERTExtractionUpdate,
    AnalysisUpdate,
    SQLSimpleGenerationUpdate,
    ResultsDataUpdate,
    ResearchObjectOutput,
    DeepSearchFilterUpdate,
    ResearchObjectUpdate,
    ConsolidatedResearchUpdate,
):
    pass


## Graph Output State - presently not used
class MainOutput(TypedDict):
    log_messages: list[str]


class ResearchObjectInput(LoggerUpdate):
    entity: str
    broken_down_question: str
    vespa_filter_results: KGVespaFilterResults
    source_division: bool | None
@@ -1,8 +1,6 @@
from typing import cast
from uuid import uuid4

from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
@@ -12,21 +10,13 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
    get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -40,49 +30,6 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()


def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
    # TODO: Add trimming logic
    history_segments = []
    for msg in prompt_builder.message_history:
        if isinstance(msg, HumanMessage):
            role = "User"
        elif isinstance(msg, AIMessage):
            role = "Assistant"
        else:
            continue
        history_segments.append(f"{role}:\n {msg.content}\n\n")
    return "\n".join(history_segments)


def _expand_query(
    query: str,
    expansion_type: QueryExpansionType,
    prompt_builder: AnswerPromptBuilder,
) -> str:

    history_str = _create_history_str(prompt_builder)

    if history_str:
        if expansion_type == QueryExpansionType.KEYWORD:
            base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
        else:
            base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
        expansion_prompt = base_prompt.format(question=query, history=history_str)
    else:
        if expansion_type == QueryExpansionType.KEYWORD:
            base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
        else:
            base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
        expansion_prompt = base_prompt.format(question=query)

    msg = HumanMessage(content=expansion_prompt)
    primary_llm, _ = get_default_llms()
    response = primary_llm.invoke([msg])
    rephrased_query: str = cast(str, response.content)

    return rephrased_query


# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
@@ -105,16 +52,7 @@ def choose_tool(

    embedding_thread: TimeoutThread[Embedding] | None = None
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None
    override_kwargs: SearchToolOverrideKwargs | None = None

    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    if (
        not agent_config.behavior.use_agentic_search
        and agent_config.tooling.search_tool is not None
@@ -134,20 +72,11 @@ def choose_tool(
            agent_config.inputs.search_request.query,
        )

        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
            using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
            prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

            expanded_keyword_thread = run_in_background(
                _expand_query,
                agent_config.inputs.search_request.query,
                QueryExpansionType.KEYWORD,
                prompt_builder,
            )
            expanded_semantic_thread = run_in_background(
                _expand_query,
                agent_config.inputs.search_request.query,
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )
    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    structured_response_format = agent_config.inputs.structured_response_format
    tools = [
@@ -280,19 +209,6 @@ def choose_tool(
        override_kwargs.precomputed_is_keyword = is_keyword
        override_kwargs.precomputed_keywords = keywords

    if (
        selected_tool.name == SearchTool._NAME
        and expanded_keyword_thread
        and expanded_semantic_thread
    ):
        keyword_expansion = wait_on_background(expanded_keyword_thread)
        semantic_expansion = wait_on_background(expanded_semantic_thread)
        assert override_kwargs is not None, "must have override kwargs"
        override_kwargs.expanded_queries = QueryExpansions(
            keywords_expansions=[keyword_expansion],
            semantic_expansions=[semantic_expansion],
        )

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,

@@ -9,7 +9,6 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.context.search.utils import dedupe_documents
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -51,16 +50,16 @@ def basic_use_tool_response(

    final_search_results = []
    initial_search_results = []
    initial_search_document_ids: set[str] = set()
    for yield_item in tool_call_responses:
        if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
            final_search_results = cast(list[LlmDoc], yield_item.response)
        elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
            search_response_summary = cast(SearchResponseSummary, yield_item.response)
            # use same function from _handle_search_tool_response_summary
            initial_search_results = [
                section_to_llm_doc(section)
                for section in dedupe_documents(search_response_summary.top_sections)[0]
            ]
            for section in search_response_summary.top_sections:
                if section.center_chunk.document_id not in initial_search_document_ids:
                    initial_search_document_ids.add(section.center_chunk.document_id)
                    initial_search_results.append(section_to_llm_doc(section))

    new_tool_call_chunk = AIMessageChunk(content="")
    if not agent_config.behavior.skip_gen_ai_answer_generation:

@@ -12,16 +12,12 @@ from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
    divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.dc_search_analysis.graph_builder import dc_graph_builder
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.deep_search.main.graph_builder import (
    main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
    MainInput as MainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
@@ -90,7 +86,7 @@ def _parse_agent_event(
def manage_sync_streaming(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    graph_input: BasicInput | MainInput | KBMainInput | DCMainInput,
    graph_input: BasicInput | MainInput | DCMainInput,
) -> Iterable[StreamEvent]:
    message_id = config.persistence.message_id if config.persistence else None
    for event in compiled_graph.stream(
@@ -104,7 +100,7 @@ def manage_sync_streaming(
def run_graph(
    compiled_graph: CompiledStateGraph,
    config: GraphConfig,
    input: BasicInput | MainInput | KBMainInput | DCMainInput,
    input: BasicInput | MainInput | DCMainInput,
) -> AnswerStream:
    config.behavior.perform_initial_search_decomposition = (
        INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -154,15 +150,6 @@ def run_basic_graph(
    return run_graph(compiled_graph, config, input)


def run_kb_graph(
    config: GraphConfig,
) -> AnswerStream:
    graph = kb_graph_builder()
    compiled_graph = graph.compile()
    input = KBMainInput(log_messages=[])
    return run_graph(compiled_graph, config, input)


def run_dc_graph(
    config: GraphConfig,
) -> AnswerStream:

@@ -27,8 +27,6 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

logger = setup_logger()


def build_sub_question_answer_prompt(
    question: str,

@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel
@@ -154,14 +153,3 @@ class AnswerGenerationDocuments(BaseModel):


BaseMessage_Content = str | list[str | dict[str, Any]]


class QueryExpansionType(Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"


class ReferenceResults(BaseModel):
    # citations: list[InferenceSection]
    citations: list[str]
    general_entities: list[str]

@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import is_in_repeated_error_state
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
@@ -55,12 +54,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
@@ -243,16 +241,6 @@ def monitor_ccpair_indexing_taskset(
    if not payload:
        return

    # if the CC Pair is `SCHEDULED`, moved it to `INITIAL_INDEXING`. A CC Pair
    # should only ever be `SCHEDULED` if it's a new connector.
    cc_pair = get_connector_credential_pair_from_id(db_session, cc_pair_id)
    if cc_pair is None:
        raise RuntimeError(f"CC Pair {cc_pair_id} not found")

    if cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED:
        cc_pair.status = ConnectorCredentialPairStatus.INITIAL_INDEXING
        db_session.commit()

    elapsed_started_str = None
    if payload.started:
        elapsed_started = datetime.now(timezone.utc) - payload.started
@@ -367,22 +355,6 @@ def monitor_ccpair_indexing_taskset(

    redis_connector_index.reset()

    # mark the CC Pair as `ACTIVE` if it's not already
    if (
        # it should never technically be in this state, but we'll handle it anyway
        cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED
        or cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
    ):
        cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
        db_session.commit()

    # if the index attempt is successful, clear the repeated error state
    if cc_pair.in_repeated_error_state:
        index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
        if index_attempt and index_attempt.status.is_successful():
            cc_pair.in_repeated_error_state = False
            db_session.commit()


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_INDEXING,
@@ -469,21 +441,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
        for cc_pair_entry in cc_pairs:
            cc_pair_ids.append(cc_pair_entry.id)

    # mark CC Pairs that are repeatedly failing as in repeated error state
    with get_session_with_current_tenant() as db_session:
        current_search_settings = get_current_search_settings(db_session)
        for cc_pair_id in cc_pair_ids:
            if is_in_repeated_error_state(
                cc_pair_id=cc_pair_id,
                search_settings_id=current_search_settings.id,
                db_session=db_session,
            ):
                set_cc_pair_repeated_error_state(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,
                    in_repeated_error_state=True,
                )

    # kick off index attempts
    for cc_pair_id in cc_pair_ids:
        lock_beat.reacquire()
@@ -523,8 +480,13 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                )
                continue

            last_attempt = get_last_attempt_for_cc_pair(
                cc_pair.id, search_settings_instance.id, db_session
            )

            if not should_index(
                cc_pair=cc_pair,
                last_index=last_attempt,
                search_settings_instance=search_settings_instance,
                secondary_index_building=len(search_settings_list) > 1,
                db_session=db_session,
@@ -532,6 +494,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                task_logger.info(
                    f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
                    f"search_settings={search_settings_instance.id}, "
                    f"last_attempt={last_attempt.id if last_attempt else None}, "
                    f"secondary_index_building={len(search_settings_list) > 1}"
                )
                continue
@@ -539,6 +502,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                task_logger.info(
                    f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
                    f"search_settings={search_settings_instance.id}, "
                    f"last_attempt={last_attempt.id if last_attempt else None}, "
                    f"secondary_index_building={len(search_settings_list) > 1}"
                )


@@ -22,7 +22,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -32,8 +31,6 @@ from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
@@ -47,8 +44,6 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5


def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
    """Gets a list of unfenced index attempts. Should not be possible, so we'd typically
@@ -351,42 +346,9 @@ def validate_indexing_fences(
    return


def is_in_repeated_error_state(
    cc_pair_id: int, search_settings_id: int, db_session: Session
) -> bool:
    """Checks if the cc pair / search setting combination is in a repeated error state."""
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        raise RuntimeError(
            f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
        )

    # if the connector doesn't have a refresh_freq, a single failed attempt is enough
    number_of_failed_attempts_in_a_row_needed = (
        NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
        if cc_pair.connector.refresh_freq is not None
        else 1
    )

    most_recent_index_attempts = get_recent_attempts_for_cc_pair(
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        limit=number_of_failed_attempts_in_a_row_needed,
        db_session=db_session,
    )
    return len(
        most_recent_index_attempts
    ) >= number_of_failed_attempts_in_a_row_needed and all(
        attempt.status == IndexingStatus.FAILED
        for attempt in most_recent_index_attempts
    )


def should_index(
    cc_pair: ConnectorCredentialPair,
    last_index: IndexAttempt | None,
    search_settings_instance: SearchSettings,
    secondary_index_building: bool,
    db_session: Session,
@@ -400,16 +362,6 @@ def should_index(
    Return True if we should try to index, False if not.
    """
    connector = cc_pair.connector
    last_index_attempt = get_last_attempt_for_cc_pair(
        cc_pair_id=cc_pair.id,
        search_settings_id=search_settings_instance.id,
        db_session=db_session,
    )
    all_recent_errored = is_in_repeated_error_state(
        cc_pair_id=cc_pair.id,
        search_settings_id=search_settings_instance.id,
        db_session=db_session,
    )

    # uncomment for debugging
    # task_logger.info(f"_should_index: "
@@ -436,24 +388,24 @@ def should_index(

    # When switching over models, always index at least once
    if search_settings_instance.status == IndexModelStatus.FUTURE:
        if last_index_attempt:
        if last_index:
            # No new index if the last index attempt succeeded
            # Once is enough. The model will never be able to swap otherwise.
            if last_index_attempt.status == IndexingStatus.SUCCESS:
            if last_index.status == IndexingStatus.SUCCESS:
                # print(
                #     f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
                # )
                return False

            # No new index if the last index attempt is waiting to start
            if last_index_attempt.status == IndexingStatus.NOT_STARTED:
            if last_index.status == IndexingStatus.NOT_STARTED:
                # print(
                #     f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
                # )
                return False

            # No new index if the last index attempt is running
            if last_index_attempt.status == IndexingStatus.IN_PROGRESS:
            if last_index.status == IndexingStatus.IN_PROGRESS:
                # print(
                #     f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
                # )
@@ -487,27 +439,18 @@ def should_index(
        return True

    # if no attempt has ever occurred, we should index regardless of refresh_freq
    if not last_index_attempt:
    if not last_index:
        return True

    if connector.refresh_freq is None:
        # print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
        return False

    # if in the "initial" phase, we should always try and kick-off indexing
    # as soon as possible if there is no ongoing attempt. In other words,
    # no delay UNLESS we're repeatedly failing to index.
    if (
        cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
        and not all_recent_errored
    ):
        return True

    current_db_time = get_db_current_time(db_session)
    time_since_index = current_db_time - last_index_attempt.time_updated
    time_since_index = current_db_time - last_index.time_updated
    if time_since_index.total_seconds() < connector.refresh_freq:
        # print(
        #     f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index_attempt.id} "
        #     f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
        #     f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
        # )
        return False

@@ -11,8 +11,6 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -147,21 +145,13 @@ class Answer:

        if self.graph_config.behavior.use_agentic_search:
            run_langgraph = run_main_graph
        elif self.graph_config.inputs.search_request.persona:
            if (
        elif (
            self.graph_config.inputs.search_request.persona
            and self.graph_config.inputs.search_request.persona.description.startswith(
                "DivCon Beta Agent"
                "DivCon Beta Agent"
            )
        ):
            run_langgraph = run_dc_graph
            elif self.graph_config.inputs.search_request.persona.name.startswith(
                "KG Dev"
            ):
                run_langgraph = run_kb_graph
            else:
                run_langgraph = run_basic_graph

        ):
            run_langgraph = run_dc_graph
        else:
            run_langgraph = run_basic_graph


@@ -13,7 +13,6 @@ from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
@@ -101,39 +100,6 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
    )


def inference_section_from_llm_doc(llm_doc: LlmDoc) -> InferenceSection:
    # Create a center chunk first
    center_chunk = InferenceChunk(
        document_id=llm_doc.document_id,
        chunk_id=0,  # Default to 0 since LlmDoc doesn't have this info
        content=llm_doc.content,
        blurb=llm_doc.blurb,
        semantic_identifier=llm_doc.semantic_identifier,
        source_type=llm_doc.source_type,
        metadata=llm_doc.metadata,
        updated_at=llm_doc.updated_at,
        source_links=llm_doc.source_links or {},
        match_highlights=llm_doc.match_highlights or [],
        section_continuation=False,
        image_file_name=None,
        title=None,
        boost=1,
        recency_bias=1.0,
        score=None,
        hidden=False,
        doc_summary="",
        chunk_context="",
    )

    # Create InferenceSection with the center chunk
    # Since we don't have access to the original chunks, we'll use an empty list
    return InferenceSection(
        center_chunk=center_chunk,
        chunks=[],  # Original surrounding chunks are not available in LlmDoc
        combined_content=llm_doc.content,
    )


def combine_message_thread(
    messages: list[ThreadMessage],
    max_tokens: int | None,

@@ -2,11 +2,9 @@ import time
import traceback
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import partial
from typing import cast
from typing import Protocol
from uuid import UUID

from sqlalchemy.orm import Session

@@ -84,8 +82,6 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
@@ -100,8 +96,6 @@ from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_file_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -165,25 +159,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"

COMMON_TOOL_RESPONSE_TYPES = {
    "image": ChatFileType.IMAGE,
    "csv": ChatFileType.CSV,
}


class PartialResponse(Protocol):
    def __call__(
        self,
        message: str,
        rephrased_query: str | None,
        reference_docs: list[DbSearchDoc] | None,
        files: list[FileDescriptor],
        token_count: int,
        citations: dict[int, int] | None,
        error: str | None,
        tool_call: ToolCall | None,
    ) -> ChatMessage: ...


def _translate_citations(
    citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
@@ -236,25 +211,25 @@ def _handle_search_tool_response_summary(
    reference_db_search_docs = selected_search_docs

    doc_ids = {doc.id for doc in reference_db_search_docs}
    if user_files is not None and loaded_user_files is not None:
    if user_files is not None:
        for user_file in user_files:
            if user_file.id in doc_ids:
                continue

            associated_chat_file = next(
                (
                    file
                    for file in loaded_user_files
                    if file.file_id == str(user_file.file_id)
                ),
                None,
            )
            # Use create_search_doc_from_user_file to properly add the document to the database
            if associated_chat_file is not None:
                db_doc = create_search_doc_from_user_file(
                    user_file, associated_chat_file, db_session
                )
                reference_db_search_docs.append(db_doc)
            if user_file.id not in doc_ids:
                associated_chat_file = None
                if loaded_user_files is not None:
                    associated_chat_file = next(
                        (
                            file
                            for file in loaded_user_files
                            if file.file_id == str(user_file.file_id)
                        ),
                        None,
                    )
                # Use create_search_doc_from_user_file to properly add the document to the database
                if associated_chat_file is not None:
                    db_doc = create_search_doc_from_user_file(
                        user_file, associated_chat_file, db_session
                    )
                    reference_db_search_docs.append(db_doc)

    response_docs = [
        translate_db_search_doc_to_server_search_doc(db_search_doc)
@@ -382,86 +357,6 @@ def _get_force_search_settings(
    )


def _get_user_knowledge_files(
    info: AnswerPostInfo,
    user_files: list[InMemoryChatFile],
    file_id_to_user_file: dict[str, InMemoryChatFile],
) -> Generator[UserKnowledgeFilePacket, None, None]:
    if not info.qa_docs_response:
        return

    logger.info(
        f"ORDERING: Processing search results for ordering {len(user_files)} user files"
    )

    # Extract document order from search results
    doc_order = []
    for doc in info.qa_docs_response.top_documents:
        doc_id = doc.document_id
        if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
            file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
            if file_id in file_id_to_user_file:
                doc_order.append(file_id)

    logger.info(f"ORDERING: Found {len(doc_order)} files from search results")

    # Add any files that weren't in search results at the end
    missing_files = [
        f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
    ]

    missing_files.extend(doc_order)
    doc_order = missing_files

    logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")

    # Reorder user files based on search results
ordered_user_files = [
|
||||
file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
|
||||
]
|
||||
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id),
|
||||
type=ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
for file in ordered_user_files
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_persona_for_chat_session(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
default_persona: Persona,
|
||||
) -> Persona:
|
||||
if new_msg_req.alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
new_msg_req.alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = default_persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
return persona
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
@@ -483,149 +378,6 @@ ChatPacket = (
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
|
||||
def _process_tool_response(
|
||||
packet: ToolResponse,
|
||||
db_session: Session,
|
||||
selected_db_search_docs: list[DbSearchDoc] | None,
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
|
||||
retrieval_options: RetrievalDetails | None,
|
||||
user_file_files: list[UserFile] | None,
|
||||
user_files: list[InMemoryChatFile] | None,
|
||||
file_id_to_user_file: dict[str, InMemoryChatFile],
|
||||
search_for_ordering_only: bool,
|
||||
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if isinstance(packet, ExtendedToolResponse)
|
||||
else BASIC_KEY
|
||||
)
|
||||
|
||||
assert level is not None
|
||||
assert level_question_num is not None
|
||||
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
|
||||
|
||||
# Skip LLM relevance processing entirely for ordering-only mode
|
||||
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
logger.info(
|
||||
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
|
||||
)
|
||||
# Skip this packet entirely since it would trigger LLM processing
|
||||
return info_by_subq
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Skipping document deduplication for ordering-only mode"
|
||||
)
|
||||
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
info.dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
# Skip deduping completely for ordering-only mode to save time
|
||||
dedupe_docs=bool(
|
||||
not search_for_ordering_only
|
||||
and retrieval_options
|
||||
and retrieval_options.dedupe_docs
|
||||
),
|
||||
user_files=user_file_files if search_for_ordering_only else [],
|
||||
loaded_user_files=(user_files if search_for_ordering_only else []),
|
||||
)
|
||||
|
||||
# If we're using search just for ordering user files
|
||||
if search_for_ordering_only and user_files:
|
||||
yield from _get_user_knowledge_files(
|
||||
info=info,
|
||||
user_files=user_files,
|
||||
file_id_to_user_file=file_id_to_user_file,
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Skipping relevance filtering for ordering-only mode"
|
||||
)
|
||||
return info_by_subq
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning("No reference docs found for relevance filtering")
|
||||
return info_by_subq
|
||||
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in info.reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if info.dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=info.reference_db_search_docs,
|
||||
dropped_indices=info.dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(llm_selected_doc_indices=llm_indices)
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(final_context_docs=packet.response)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(list[ImageGenerationResponse], packet.response)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data for img in img_generation_response if img.image_data
|
||||
],
|
||||
)
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
response_type = custom_tool_response.response_type
|
||||
if response_type in COMMON_TOOL_RESPONSE_TYPES:
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
file_type = COMMON_TOOL_RESPONSE_TYPES[response_type]
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=file_type)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
return info_by_subq
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
@@ -666,20 +418,10 @@ def stream_chat_message_objects(
|
||||
|
||||
llm: LLM
|
||||
|
||||
index_str = "danswer_chunk_text_embedding_3_small"
|
||||
|
||||
if new_msg_req.message == "ee":
|
||||
kg_extraction(tenant_id, index_str)
|
||||
|
||||
raise Exception("Extractions done")
|
||||
|
||||
elif new_msg_req.message == "cc":
|
||||
kg_clustering(tenant_id, index_str)
|
||||
raise Exception("Clustering done")
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
file_id_to_user_file = {}
|
||||
ordered_user_files = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@@ -694,19 +436,35 @@ def stream_chat_message_objects(
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
new_msg_req.alternate_assistant_id
|
||||
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
persona = _get_persona_for_chat_session(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
default_persona=chat_session.persona,
|
||||
)
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
|
||||
user=user,
|
||||
@@ -988,42 +746,31 @@ def stream_chat_message_objects(
|
||||
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
||||
)
|
||||
|
||||
def create_response(
|
||||
message: str,
|
||||
rephrased_query: str | None,
|
||||
reference_docs: list[DbSearchDoc] | None,
|
||||
files: list[FileDescriptor],
|
||||
token_count: int,
|
||||
citations: dict[int, int] | None,
|
||||
error: str | None,
|
||||
tool_call: ToolCall | None,
|
||||
) -> ChatMessage:
|
||||
return create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=(
|
||||
final_msg
|
||||
if existing_assistant_message_id is None
|
||||
else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message=message,
|
||||
rephrased_query=rephrased_query,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
error=error,
|
||||
reference_docs=reference_docs,
|
||||
files=files,
|
||||
citations=citations,
|
||||
tool_call=tool_call,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
partial_response = create_response
|
||||
# Cannot determine these without the LLM step or breaking out early
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
# rephrased_query=,
|
||||
# token_count=,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
# error=,
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
@@ -1294,23 +1041,220 @@ def stream_chat_message_objects(
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
# reference_db_search_docs = None
|
||||
# qa_docs_response = None
|
||||
# # any files to associate with the AI message e.g. dall-e generated images
|
||||
# ai_message_files = []
|
||||
# dropped_indices = None
|
||||
# tool_result = None
|
||||
|
||||
# TODO: different channels for stored info when it's coming from the agent flow
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
lambda: AnswerPostInfo(ai_message_files=[])
|
||||
)
|
||||
refined_answer_improvement = True
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
info_by_subq = yield from _process_tool_response(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_db_search_docs=selected_db_search_docs,
|
||||
info_by_subq=info_by_subq,
|
||||
retrieval_options=retrieval_options,
|
||||
user_file_files=user_file_files,
|
||||
user_files=user_files,
|
||||
file_id_to_user_file=file_id_to_user_file,
|
||||
search_for_ordering_only=search_for_ordering_only,
|
||||
level, level_question_num = (
|
||||
(packet.level, packet.level_question_num)
|
||||
if isinstance(packet, ExtendedToolResponse)
|
||||
else BASIC_KEY
|
||||
)
|
||||
assert level is not None
|
||||
assert level_question_num is not None
|
||||
info = info_by_subq[
|
||||
SubQuestionKey(level=level, question_num=level_question_num)
|
||||
]
|
||||
|
||||
# Skip LLM relevance processing entirely for ordering-only mode
|
||||
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
logger.info(
|
||||
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
|
||||
)
|
||||
# Skip this packet entirely since it would trigger LLM processing
|
||||
continue
|
||||
|
||||
# TODO: don't need to dedupe here when we do it in agent flow
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Skipping document deduplication for ordering-only mode"
|
||||
)
|
||||
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
info.dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
# Skip deduping completely for ordering-only mode to save time
|
||||
dedupe_docs=bool(
|
||||
not search_for_ordering_only
|
||||
and retrieval_options
|
||||
and retrieval_options.dedupe_docs
|
||||
),
|
||||
user_files=user_file_files if search_for_ordering_only else [],
|
||||
loaded_user_files=(
|
||||
user_files if search_for_ordering_only else []
|
||||
),
|
||||
)
|
||||
|
||||
# If we're using search just for ordering user files
|
||||
if (
|
||||
search_for_ordering_only
|
||||
and user_files
|
||||
and info.qa_docs_response
|
||||
):
|
||||
logger.info(
|
||||
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
|
||||
)
|
||||
|
||||
# Extract document order from search results
|
||||
doc_order = []
|
||||
for doc in info.qa_docs_response.top_documents:
|
||||
doc_id = doc.document_id
|
||||
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
|
||||
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
|
||||
if file_id in file_id_to_user_file:
|
||||
doc_order.append(file_id)
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Found {len(doc_order)} files from search results"
|
||||
)
|
||||
|
||||
# Add any files that weren't in search results at the end
|
||||
missing_files = [
|
||||
f_id
|
||||
for f_id in file_id_to_user_file.keys()
|
||||
if f_id not in doc_order
|
||||
]
|
||||
|
||||
missing_files.extend(doc_order)
|
||||
doc_order = missing_files
|
||||
|
||||
logger.info(
|
||||
f"ORDERING: Added {len(missing_files)} missing files to the end"
|
||||
)
|
||||
|
||||
# Reorder user files based on search results
|
||||
ordered_user_files = [
|
||||
file_id_to_user_file[f_id]
|
||||
for f_id in doc_order
|
||||
if f_id in file_id_to_user_file
|
||||
]
|
||||
|
||||
yield UserKnowledgeFilePacket(
|
||||
user_files=[
|
||||
FileDescriptor(
|
||||
id=str(file.file_id),
|
||||
type=ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
for file in ordered_user_files
|
||||
]
|
||||
)
|
||||
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if search_for_ordering_only:
|
||||
logger.info(
|
||||
"Performance: Skipping relevance filtering for ordering-only mode"
|
||||
)
|
||||
continue
|
||||
|
||||
if info.reference_db_search_docs is None:
|
||||
logger.warning(
|
||||
"No reference docs found for relevance filtering"
|
||||
)
|
||||
continue
|
||||
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in info.reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if info.dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=info.reference_db_search_docs,
|
||||
dropped_indices=info.dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in img_generation_response if img.url],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in img_generation_response
|
||||
if img.image_data
|
||||
],
|
||||
)
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
info.qa_docs_response,
|
||||
info.reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield info.qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
|
||||
if (
|
||||
custom_tool_response.response_type == "image"
|
||||
or custom_tool_response.response_type == "csv"
|
||||
):
|
||||
file_ids = custom_tool_response.tool_result.file_ids
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
FileDescriptor(
|
||||
id=str(file_id),
|
||||
type=(
|
||||
ChatFileType.IMAGE
|
||||
if custom_tool_response.response_type == "image"
|
||||
else ChatFileType.CSV
|
||||
),
|
||||
)
|
||||
for file_id in file_ids
|
||||
]
|
||||
)
|
||||
yield FileChatDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
else:
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason == StreamStopReason.FINISHED:
|
||||
@@ -1347,46 +1291,22 @@ def stream_chat_message_objects(
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(error=error_msg, stack_trace=stack_trace)
|
||||
elif llm:
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
client_error_msg = client_error_msg.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
stack_trace = stack_trace.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
else:
|
||||
if llm:
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
error_msg = error_msg.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
stack_trace = stack_trace.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
yield from _post_llm_answer_processing(
|
||||
answer=answer,
|
||||
info_by_subq=info_by_subq,
|
||||
tool_dict=tool_dict,
|
||||
partial_response=partial_response,
|
||||
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
)
|
||||
|
||||
|
||||
def _post_llm_answer_processing(
|
||||
answer: Answer,
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo],
|
||||
tool_dict: dict[int, list[Tool]],
|
||||
partial_response: PartialResponse,
|
||||
llm_tokenizer_encode_func: Callable[[str], list[int]],
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
refined_answer_improvement: bool | None,
|
||||
) -> Generator[ChatPacket, None, None]:
|
||||
"""
|
||||
Stores messages in the db and yields some final packets to the frontend
|
||||
"""
|
||||
# Post-LLM answer processing
|
||||
try:
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
|
||||
@@ -483,6 +483,14 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
|
||||
DISABLE_INDEX_UPDATE_ON_SWAP = (
|
||||
os.environ.get("DISABLE_INDEX_UPDATE_ON_SWAP", "").lower() == "true"
|
||||
)
|
||||
# Controls how many worker processes we spin up to index documents in the
|
||||
# background. This is useful for speeding up indexing, but does require a
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
|
||||
)
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
|
||||
@@ -96,9 +96,3 @@ BING_API_KEY = os.environ.get("BING_API_KEY") or None
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Whether or not to use the semantic & keyword search expansions for Basic Search
|
||||
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
KG_OWN_EMAIL_DOMAINS: list[str] = json.loads(
|
||||
os.environ.get("KG_OWN_EMAIL_DOMAINS", "[]")
|
||||
) # must be list
|
||||
|
||||
KG_IGNORE_EMAIL_DOMAINS: list[str] = json.loads(
|
||||
os.environ.get("KG_IGNORE_EMAIL_DOMAINS", "[]")
|
||||
) # must be list
|
||||
|
||||
KG_OWN_COMPANY: str = os.environ.get("KG_OWN_COMPANY", "")
|
||||
@@ -101,7 +101,6 @@ class ConfluenceConnector(
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
|
||||
@@ -157,12 +156,6 @@ class ConfluenceConnector(
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
@property
|
||||
def low_timeout_confluence_client(self) -> OnyxConfluence:
|
||||
if self._low_timeout_confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._low_timeout_confluence_client
|
||||
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
@@ -170,27 +163,13 @@ class ConfluenceConnector(
|
||||
|
||||
# raises exception if there's a problem
|
||||
confluence_client = OnyxConfluence(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
self.is_cloud, self.wiki_base, credentials_provider
|
||||
)
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._confluence_client = confluence_client
|
||||
|
||||
# create a low timeout confluence client for sync flows
|
||||
low_timeout_confluence_client = OnyxConfluence(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
timeout=3,
|
||||
)
|
||||
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
|
||||
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._low_timeout_confluence_client = low_timeout_confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
@@ -542,8 +521,11 @@ class ConfluenceConnector(
|
||||
yield doc_metadata_list
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
|
||||
|
||||
try:
|
||||
spaces = self.low_timeout_confluence_client.get_all_spaces(limit=1)
|
||||
spaces = self._confluence_client.get_all_spaces(limit=1)
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response else None
|
||||
if status_code == 401:
|
||||
|
||||
@@ -72,14 +72,12 @@ class OnyxConfluence:
|
||||
|
||||
CREDENTIAL_PREFIX = "connector:confluence:credential"
|
||||
CREDENTIAL_TTL = 300 # 5 min
|
||||
PROBE_TIMEOUT = 5 # 5 seconds
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_cloud: bool,
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
timeout: int | None = None,
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
@@ -102,13 +100,11 @@ class OnyxConfluence:
|
||||
|
||||
self._kwargs: Any = None
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
self.shared_base_kwargs = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
self.shared_base_kwargs["timeout"] = timeout
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||
"""credential_json - the current json credentials
|
||||
@@ -195,8 +191,6 @@ class OnyxConfluence:
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
# add special timeout to make sure that we don't hang indefinitely
|
||||
merged_kwargs["timeout"] = self.PROBE_TIMEOUT
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
|
||||
@@ -18,17 +18,11 @@ from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
)
|
||||
|
||||
|
||||
class QueryExpansions(BaseModel):
|
||||
keywords_expansions: list[str] | None = None
|
||||
semantic_expansions: list[str] | None = None
|
||||
|
||||
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
@@ -113,9 +107,6 @@ class BaseFilters(BaseModel):
|
||||
tags: list[Tag] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
@@ -148,8 +139,6 @@ class ChunkContext(BaseModel):
|
||||
class SearchRequest(ChunkContext):
|
||||
query: str
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
|
||||
human_selected_filters: BaseFilters | None = None
|
||||
@@ -198,8 +187,6 @@ class SearchQuery(ChunkContext):
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
|
||||
@@ -68,7 +68,6 @@ class SearchPipeline:
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
|
||||
self.search_request = search_request
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.utils import (
|
||||
from onyx.context.search.retrieval.search_runner import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from onyx.db.models import User
|
||||
@@ -36,6 +36,7 @@ from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -182,9 +183,6 @@ def retrieval_preprocessing(
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
kg_entities=preset_filters.kg_entities,
|
||||
kg_relationships=preset_filters.kg_relationships,
|
||||
kg_terms=preset_filters.kg_terms,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
@@ -266,5 +264,4 @@ def retrieval_preprocessing(
|
||||
chunks_below=chunks_below,
|
||||
full_doc=search_request.full_doc,
|
||||
precomputed_query_embedding=search_request.precomputed_query_embedding,
|
||||
expanded_queries=search_request.expanded_queries,
|
||||
)
|
||||
|
||||
@@ -2,10 +2,10 @@ import string
|
||||
from collections.abc import Callable
|
||||
|
||||
import nltk # type:ignore
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import ChunkMetric
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -15,8 +15,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RetrievalMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
|
||||
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
@@ -29,9 +27,6 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
@@ -41,23 +36,6 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _dedupe_chunks(
|
||||
chunks: list[InferenceChunkUncleaned],
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
used_chunks: dict[tuple[str, int], InferenceChunkUncleaned] = {}
|
||||
for chunk in chunks:
|
||||
key = (chunk.document_id, chunk.chunk_id)
|
||||
if key not in used_chunks:
|
||||
used_chunks[key] = chunk
|
||||
else:
|
||||
stored_chunk_score = used_chunks[key].score or 0
|
||||
this_chunk_score = chunk.score or 0
|
||||
if stored_chunk_score < this_chunk_score:
|
||||
used_chunks[key] = chunk
|
||||
|
||||
return list(used_chunks.values())
|
||||
|
||||
|
||||
def download_nltk_data() -> None:
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
@@ -91,6 +69,22 @@ def lemmatize_text(keywords: list[str]) -> list[str]:
|
||||
# return keywords
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
query = " ".join(keywords)
|
||||
stop_words = set(stopwords.words("english"))
|
||||
word_tokens = word_tokenize(query)
|
||||
text_trimmed = [
|
||||
word
|
||||
for word in word_tokens
|
||||
if (word.casefold() not in stop_words and word not in string.punctuation)
|
||||
]
|
||||
return text_trimmed or word_tokens
|
||||
except Exception:
|
||||
return keywords
|
||||
|
||||
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -129,20 +123,6 @@ def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
return query_embedding
|
||||
|
||||
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
@@ -159,113 +139,17 @@ def doc_index_retrieval(
|
||||
query.query, db_session
|
||||
)
|
||||
|
||||
keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
|
||||
semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
|
||||
top_base_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = None
|
||||
|
||||
top_semantic_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = (
|
||||
None
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
final_keywords=query.processed_keywords,
|
||||
filters=query.filters,
|
||||
hybrid_alpha=query.hybrid_alpha,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
offset=query.offset,
|
||||
)
|
||||
|
||||
keyword_embeddings: list[Embedding] | None = None
|
||||
semantic_embeddings: list[Embedding] | None = None
|
||||
|
||||
top_semantic_chunks: list[InferenceChunkUncleaned] | None = None
|
||||
|
||||
# original retrieveal method
|
||||
top_base_chunks_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.query,
|
||||
query_embedding,
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
query.hybrid_alpha,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
"semantic",
|
||||
query.offset,
|
||||
)
|
||||
|
||||
if (
|
||||
query.expanded_queries
|
||||
and query.expanded_queries.keywords_expansions
|
||||
and query.expanded_queries.semantic_expansions
|
||||
):
|
||||
|
||||
keyword_embeddings_thread = run_in_background(
|
||||
get_query_embeddings,
|
||||
query.expanded_queries.keywords_expansions,
|
||||
db_session,
|
||||
)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
semantic_embeddings_thread = run_in_background(
|
||||
get_query_embeddings,
|
||||
query.expanded_queries.semantic_expansions,
|
||||
db_session,
|
||||
)
|
||||
|
||||
keyword_embeddings = wait_on_background(keyword_embeddings_thread)
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert semantic_embeddings_thread is not None
|
||||
semantic_embeddings = wait_on_background(semantic_embeddings_thread)
|
||||
|
||||
# Use original query embedding for keyword retrieval embedding
|
||||
keyword_embeddings = [query_embedding]
|
||||
|
||||
# Note: we generally prepped earlier for multiple expansions, but for now we only use one.
|
||||
top_keyword_chunks_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.expanded_queries.keywords_expansions[0],
|
||||
keyword_embeddings[0],
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
HYBRID_ALPHA_KEYWORD,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
QueryExpansionType.KEYWORD,
|
||||
query.offset,
|
||||
)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert semantic_embeddings is not None
|
||||
|
||||
top_semantic_chunks_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.expanded_queries.semantic_expansions[0],
|
||||
semantic_embeddings[0],
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
HYBRID_ALPHA,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
QueryExpansionType.SEMANTIC,
|
||||
query.offset,
|
||||
)
|
||||
|
||||
top_base_chunks = wait_on_background(top_base_chunks_thread)
|
||||
|
||||
top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert top_semantic_chunks_thread is not None
|
||||
top_semantic_chunks = wait_on_background(top_semantic_chunks_thread)
|
||||
|
||||
all_top_chunks = top_base_chunks + top_keyword_chunks
|
||||
|
||||
# use all three retrieval methods to retrieve top chunks
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
|
||||
|
||||
all_top_chunks += top_semantic_chunks
|
||||
|
||||
top_chunks = _dedupe_chunks(all_top_chunks)
|
||||
|
||||
else:
|
||||
|
||||
top_base_chunks = wait_on_background(top_base_chunks_thread)
|
||||
top_chunks = _dedupe_chunks(top_base_chunks)
|
||||
|
||||
retrieval_requests: list[VespaChunkRequest] = []
|
||||
normal_chunks: list[InferenceChunkUncleaned] = []
|
||||
referenced_chunk_scores: dict[tuple[str, int], float] = {}
|
||||
@@ -349,8 +233,6 @@ def retrieve_chunks(
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
logger.info(f"RETRIEVAL CHUNKS query: {query}")
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
# Don't do query expansion on complex queries, rephrasings likely would not work well
|
||||
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
@@ -140,19 +136,3 @@ def chunks_or_sections_to_search_docs(
|
||||
]
|
||||
|
||||
return search_docs
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
query = " ".join(keywords)
|
||||
stop_words = set(stopwords.words("english"))
|
||||
word_tokens = word_tokenize(query)
|
||||
text_trimmed = [
|
||||
word
|
||||
for word in word_tokens
|
||||
if (word.casefold() not in stop_words and word not in string.punctuation)
|
||||
]
|
||||
return text_trimmed or word_tokens
|
||||
except Exception:
|
||||
return keywords
|
||||
|
||||
@@ -56,8 +56,7 @@ def get_total_users_count(db_session: Session) -> int:
|
||||
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_with_tenant() as session:
|
||||
count_stmt = func.count(User.id) # type: ignore
|
||||
stmt = select(count_stmt)
|
||||
stmt = select(func.count(User.id))
|
||||
if only_admin_users:
|
||||
stmt = stmt.where(User.role == UserRole.ADMIN)
|
||||
result = await session.execute(stmt)
|
||||
|
||||
@@ -334,21 +334,3 @@ def mark_ccpair_with_indexing_trigger(
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_unprocessed_connector_ids(db_session: Session) -> list[int]:
|
||||
"""
|
||||
Retrieves a list of connector IDs that have not been KG processed for a given tenant.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
|
||||
"""
|
||||
try:
|
||||
stmt = select(Connector.id).where(Connector.kg_extraction_enabled)
|
||||
result = db_session.execute(stmt)
|
||||
return [row[0] for row in result.fetchall()]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
@@ -395,20 +394,6 @@ def update_connector_credential_pair(
|
||||
)
|
||||
|
||||
|
||||
def set_cc_pair_repeated_error_state(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
in_repeated_error_state: bool,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.values(in_repeated_error_state=in_repeated_error_state)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
@@ -472,7 +457,7 @@ def add_credential_to_connector(
|
||||
access_type: AccessType,
|
||||
groups: list[int] | None,
|
||||
auto_sync_options: dict | None = None,
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
is_user_file: bool = False,
|
||||
|
||||
@@ -23,8 +23,6 @@ from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
@@ -379,7 +377,6 @@ def upsert_documents(
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
kg_processed=False,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -846,213 +843,3 @@ def fetch_chunk_count_for_document(
|
||||
) -> int | None:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_unprocessed_kg_documents_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
batch_size: int = 100,
|
||||
) -> Generator[DbDocument, None, None]:
|
||||
"""
|
||||
Retrieves all documents associated with a connector that have not yet been processed
|
||||
for knowledge graph extraction. Uses a generator pattern to handle large result sets.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
connector_id (int): The ID of the connector to check
|
||||
batch_size (int): Number of documents to fetch per batch, defaults to 100
|
||||
|
||||
Yields:
|
||||
DbDocument: Documents that haven't been KG processed, one at a time
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
or_(
|
||||
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
|
||||
None
|
||||
),
|
||||
DocumentByConnectorCredentialPair.has_been_kg_processed.is_(
|
||||
False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
batch = list(db_session.scalars(stmt).all())
|
||||
if not batch:
|
||||
break
|
||||
|
||||
for document in batch:
|
||||
yield document
|
||||
|
||||
offset += batch_size
|
||||
|
||||
|
||||
def get_kg_processed_document_ids(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_processed is True.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been KG processed
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_processed.is_(True))
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def update_document_kg_info(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
kg_processed: bool,
|
||||
kg_data: dict,
|
||||
) -> None:
|
||||
"""Updates the knowledge graph related information for a document.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to update
|
||||
kg_processed (bool): Whether the document has been processed for KG extraction
|
||||
kg_data (dict): Dictionary containing KG data with 'entities', 'relationships', and 'terms' keys
|
||||
|
||||
Raises:
|
||||
ValueError: If the document with the given ID is not found
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(
|
||||
kg_processed=kg_processed,
|
||||
kg_data=kg_data,
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def get_document_kg_info(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> tuple[bool, dict] | None:
|
||||
"""Retrieves the knowledge graph processing status and data for a document.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to query
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[bool, dict]]: A tuple containing:
|
||||
- bool: Whether the document has been KG processed
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Returns None if the document is not found
|
||||
"""
|
||||
stmt = select(DbDocument.kg_processed, DbDocument.kg_data).where(
|
||||
DbDocument.id == document_id
|
||||
)
|
||||
result = db_session.execute(stmt).one_or_none()
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.kg_processed, result.kg_data or {}
|
||||
|
||||
|
||||
def get_all_kg_processed_documents_info(
|
||||
db_session: Session,
|
||||
) -> list[tuple[str, dict]]:
|
||||
"""Retrieves the knowledge graph data for all documents that have been processed.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, dict]]: A list of tuples containing:
|
||||
- str: The document ID
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Only returns documents where kg_processed is True
|
||||
"""
|
||||
stmt = (
|
||||
select(DbDocument.id, DbDocument.kg_data)
|
||||
.where(DbDocument.kg_processed.is_(True))
|
||||
.order_by(DbDocument.id)
|
||||
)
|
||||
|
||||
results = db_session.execute(stmt).all()
|
||||
return [(str(doc_id), kg_data or {}) for doc_id, kg_data in results]
|
||||
|
||||
|
||||
def get_base_llm_doc_information(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> list[InferenceSection]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
results = db_session.execute(stmt).all()
|
||||
|
||||
inference_sections = []
|
||||
|
||||
for doc in results:
|
||||
bare_doc = doc[0]
|
||||
inference_section = InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
document_id=bare_doc.id,
|
||||
chunk_id=0,
|
||||
source_type=DocumentSource.NOT_APPLICABLE,
|
||||
semantic_identifier=bare_doc.semantic_id,
|
||||
title=None,
|
||||
boost=0,
|
||||
recency_bias=0,
|
||||
score=0,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
blurb="",
|
||||
content="",
|
||||
source_links=None,
|
||||
image_file_name=None,
|
||||
section_continuation=False,
|
||||
match_highlights=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
updated_at=None,
|
||||
),
|
||||
chunks=[],
|
||||
combined_content="",
|
||||
)
|
||||
|
||||
inference_sections.append(inference_section)
|
||||
return inference_sections
|
||||
|
||||
|
||||
def get_document_updated_at(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> datetime | None:
|
||||
"""Retrieves the doc_updated_at timestamp for a given document ID.
|
||||
|
||||
Args:
|
||||
document_id (str): The ID of the document to query
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
|
||||
"""
|
||||
if len(document_id.split(":")) == 2:
|
||||
document_id = document_id.split(":")[1]
|
||||
elif len(document_id.split(":")) > 2:
|
||||
raise ValueError(f"Invalid document ID: {document_id}")
|
||||
else:
|
||||
pass
|
||||
|
||||
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityType
|
||||
|
||||
|
||||
def get_entity_types(
|
||||
db_session: Session,
|
||||
active: bool | None = True,
|
||||
) -> list[KGEntityType]:
|
||||
# Query the database for all distinct entity types
|
||||
|
||||
if active is None:
|
||||
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
|
||||
|
||||
else:
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.active == active)
|
||||
.order_by(KGEntityType.id_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def add_entity(
|
||||
db_session: Session,
|
||||
entity_type: str,
|
||||
name: str,
|
||||
document_id: str | None = None,
|
||||
cluster_count: int = 0,
|
||||
event_time: datetime | None = None,
|
||||
) -> "KGEntity | None":
|
||||
"""Add a new entity to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_type: Type of the entity (must match an existing KGEntityType)
|
||||
name: Name of the entity
|
||||
cluster_count: Number of clusters this entity has been found
|
||||
|
||||
Returns:
|
||||
KGEntity: The created entity
|
||||
"""
|
||||
entity_type = entity_type.upper()
|
||||
name = name.title()
|
||||
id_name = f"{entity_type}:{name}"
|
||||
|
||||
# Create new entity
|
||||
stmt = (
|
||||
pg_insert(KGEntity)
|
||||
.values(
|
||||
id_name=id_name,
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
cluster_count=cluster_count,
|
||||
event_time=event_time,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
# Direct numeric addition without text()
|
||||
cluster_count=KGEntity.cluster_count
|
||||
+ literal_column("EXCLUDED.cluster_count"),
|
||||
# Keep other fields updated as before
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
event_time=event_time,
|
||||
),
|
||||
)
|
||||
.returning(KGEntity)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
return result
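
# Illustrative usage sketch (not part of the original change set): shows the upsert
# semantics of add_entity above; the example values are hypothetical.
def _example_add_entity_upsert(db_session: Session) -> None:
    add_entity(db_session, entity_type="account", name="acme corp", cluster_count=1)
    entity = add_entity(db_session, entity_type="account", name="acme corp", cluster_count=2)
    # id_name is normalized to "ACCOUNT:Acme Corp"; on conflict the incoming
    # cluster_count is added to the stored value, so it ends up as 3 here
    db_session.commit()  # persist the upsert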
|
||||
|
||||
|
||||
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
|
||||
"""
|
||||
Check if a document_id exists in the kg_entities table and return its id_name if found.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
document_id: The document ID to search for
|
||||
|
||||
Returns:
|
||||
The id_name of the matching KGEntity if found, None otherwise
|
||||
"""
|
||||
query = select(KGEntity).where(KGEntity.document_id == document_id)
|
||||
result = db.execute(query).scalar()
|
||||
return result
|
||||
|
||||
|
||||
def get_ungrounded_entities(db_session: Session) -> List[KGEntity]:
|
||||
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to ungrounded entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntityType.grounding == "UE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_grounded_entities(db_session: Session) -> List[KGEntity]:
|
||||
"""Get all entities whose entity type has grounding = 'UE' (ungrounded entities).
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to ungrounded entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntityType.grounding == "GE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null ge_determine_instructions.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have ge_determine_instructions defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.ge_determine_instructions.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounded_source_name(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null grounded_source_name.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounded_source_name defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounding_signature(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null ge_grounding_signature.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have ge_grounding_signature defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.ge_grounding_signature.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_ge_entities_by_types(
|
||||
db_session: Session, entity_types: List[str]
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities matching an entity_type.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity types to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.filter(KGEntityType.grounding == "GE")
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def delete_entities_by_id_names(db_session: Session, id_names: list[str]) -> int:
|
||||
"""
|
||||
Delete entities from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of entity id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of entities deleted
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGEntity)
|
||||
.filter(KGEntity.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_entities_for_types(
|
||||
db_session: Session, entity_types: List[str]
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities that belong to the specified entity types.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity type id_names to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_type_by_grounded_source_name(
|
||||
db_session: Session, grounded_source_name: str
|
||||
) -> KGEntityType | None:
|
||||
"""Get an entity type by its grounded_source_name and return it as a dictionary.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
grounded_source_name: The grounded_source_name of the entity to retrieve
|
||||
|
||||
Returns:
|
||||
Dictionary containing the entity's data with column names as keys,
|
||||
or None if the entity is not found
|
||||
"""
|
||||
entity_type = (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name == grounded_source_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
    return entity_type
|
||||
@@ -18,12 +18,6 @@ class IndexingStatus(str, PyEnum):
|
||||
}
|
||||
return self in terminal_states
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
return (
|
||||
self == IndexingStatus.SUCCESS
|
||||
or self == IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
)
|
||||
|
||||
|
||||
class IndexingMode(str, PyEnum):
|
||||
UPDATE = "update"
|
||||
@@ -79,19 +73,13 @@ class ChatSessionSharedStatus(str, PyEnum):
|
||||
|
||||
|
||||
class ConnectorCredentialPairStatus(str, PyEnum):
|
||||
SCHEDULED = "SCHEDULED"
|
||||
INITIAL_INDEXING = "INITIAL_INDEXING"
|
||||
ACTIVE = "ACTIVE"
|
||||
PAUSED = "PAUSED"
|
||||
DELETING = "DELETING"
|
||||
INVALID = "INVALID"
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return (
|
||||
self == ConnectorCredentialPairStatus.ACTIVE
|
||||
or self == ConnectorCredentialPairStatus.SCHEDULED
|
||||
or self == ConnectorCredentialPairStatus.INITIAL_INDEXING
|
||||
)
|
||||
return self == ConnectorCredentialPairStatus.ACTIVE
|
||||
|
||||
|
||||
class AccessType(str, PyEnum):
|
||||
|
||||
@@ -59,7 +59,6 @@ def get_recent_completed_attempts_for_cc_pair(
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[IndexAttempt]:
|
||||
"""Most recent to least recent."""
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
@@ -75,25 +74,6 @@ def get_recent_completed_attempts_for_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
def get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[IndexAttempt]:
|
||||
"""Most recent to least recent."""
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> IndexAttempt | None:
|
||||
|
||||
@@ -53,7 +53,6 @@ from onyx.db.enums import (
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
@@ -438,9 +437,6 @@ class ConnectorCredentialPair(Base):
|
||||
status: Mapped[ConnectorCredentialPairStatus] = mapped_column(
|
||||
Enum(ConnectorCredentialPairStatus, native_enum=False), nullable=False
|
||||
)
|
||||
# this is separate from the `status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`,
|
||||
# or `PAUSED` and still be in a repeated error state.
|
||||
in_repeated_error_state: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
connector_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector.id"), primary_key=True
|
||||
)
|
||||
@@ -587,22 +583,6 @@ class Document(Base):
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# tables for the knowledge graph data
|
||||
kg_processed: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this document has been processed for knowledge graph extraction",
|
||||
)
|
||||
|
||||
kg_data: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Knowledge graph data extracted from this document",
|
||||
)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
@@ -621,304 +601,6 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
__tablename__ = "kg_entity_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
grounding: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
ge_determine_instructions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True, default=None
|
||||
)
|
||||
|
||||
ge_grounding_signature: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=False, default=None
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this entity type",
|
||||
)
|
||||
|
||||
classification_requirements: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Pre-extraction classification requirements and instructions",
|
||||
)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
extraction_sources: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
comment="Sources and methods used to extract this entity",
|
||||
)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipType(Base):
|
||||
__tablename__ = "kg_relationship_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type",
|
||||
)
|
||||
|
||||
|
||||
class KGEntity(Base):
|
||||
__tablename__ = "kg_entity"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationship(Base):
|
||||
__tablename__ = "kg_relationship"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Source and target nodes (foreign keys to Entity table)
|
||||
source_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
target_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
# Relationship type
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
# Add new relationship type reference
|
||||
relationship_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_relationship_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Add the SQLAlchemy relationship property
|
||||
relationship_type: Mapped["KGRelationshipType"] = relationship(
|
||||
"KGRelationshipType", backref="relationship"
|
||||
)
|
||||
|
||||
cluster_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to Entity table
|
||||
source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
|
||||
target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
|
||||
|
||||
__table_args__ = (
|
||||
# Index for querying relationships by type
|
||||
Index("ix_kg_relationship_type", type),
|
||||
# Composite index for source/target queries
|
||||
Index("ix_kg_relationship_nodes", source_node, target_node),
|
||||
# Ensure unique relationships between nodes of a specific type
|
||||
UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KGTerm(Base):
|
||||
__tablename__ = "kg_term"
|
||||
|
||||
# Make id_term the primary key
|
||||
id_term: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
# List of entity types this term applies to
|
||||
entity_types: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Index for searching terms with specific entity types
|
||||
Index("ix_search_term_entities", entity_types),
|
||||
# Index for term lookups
|
||||
Index("ix_search_term_term", id_term),
|
||||
)
|
||||
|
||||
|
||||
class ChunkStats(Base):
|
||||
__tablename__ = "chunk_stats"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
@@ -1006,14 +688,6 @@ class Connector(Base):
|
||||
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
|
||||
kg_extraction_enabled: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this connector should extract knowledge graph entities",
|
||||
)
|
||||
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -1457,12 +1131,6 @@ class DocumentByConnectorCredentialPair(Base):
|
||||
# the actual indexing is complete
|
||||
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
has_been_kg_processed: Mapped[bool | None] = mapped_column(
|
||||
Boolean,
|
||||
nullable=True,
|
||||
comment="Whether this document has been processed for knowledge graph extraction",
|
||||
)
|
||||
|
||||
connector: Mapped[Connector] = relationship(
|
||||
"Connector", back_populates="documents_by_connector", passive_deletes=True
|
||||
)
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.kg.utils.formatting_utils import format_entity
|
||||
from onyx.kg.utils.formatting_utils import format_relationship
|
||||
from onyx.kg.utils.formatting_utils import generate_relationship_type
|
||||
|
||||
|
||||
def add_relationship(
|
||||
db_session: Session,
|
||||
relationship_id_name: str,
|
||||
cluster_count: int | None = None,
|
||||
) -> "KGRelationship":
|
||||
"""
|
||||
Add a relationship between two entities to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
source_entity_id: ID of the source entity
|
||||
relationship_type: Type of relationship
|
||||
target_entity_id: ID of the target entity
|
||||
cluster_count: Optional count of similar relationships clustered together
|
||||
|
||||
Returns:
|
||||
The created KGRelationship object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If the relationship already exists or entities don't exist
|
||||
"""
|
||||
    # Parse the relationship ID name into its source entity, relationship string, and target entity
|
||||
|
||||
(
|
||||
source_entity_id_name,
|
||||
relationship_string,
|
||||
target_entity_id_name,
|
||||
) = relationship_id_name.split("__")
|
||||
|
||||
source_entity_id_name = format_entity(source_entity_id_name)
|
||||
target_entity_id_name = format_entity(target_entity_id_name)
|
||||
relationship_id_name = format_relationship(relationship_id_name)
|
||||
relationship_type = generate_relationship_type(relationship_id_name)
|
||||
|
||||
# Create new relationship
|
||||
relationship = KGRelationship(
|
||||
id_name=relationship_id_name,
|
||||
source_node=source_entity_id_name,
|
||||
target_node=target_entity_id_name,
|
||||
type=relationship_string.lower(),
|
||||
relationship_type_id_name=relationship_type,
|
||||
cluster_count=cluster_count,
|
||||
)
|
||||
|
||||
db_session.add(relationship)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return relationship
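
# Illustrative sketch (not part of the original change set; the IDs are hypothetical):
# the relationship_id_name packs source entity, relationship string, and target entity
# separated by "__", e.g.
#
#   add_relationship(db_session, "ACCOUNT:Acme__uses__PRODUCT:Onyx")
#
# which (after format_entity / format_relationship normalization) yields a KGRelationship
# with source_node "ACCOUNT:Acme", type "uses", target_node "PRODUCT:Onyx", and a
# relationship_type_id_name derived via generate_relationship_type.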
|
||||
|
||||
|
||||
def add_or_increment_relationship(
|
||||
db_session: Session,
|
||||
relationship_id_name: str,
|
||||
) -> "KGRelationship":
|
||||
"""
|
||||
Add a relationship between two entities to the database if it doesn't exist,
|
||||
or increment its cluster_count by 1 if it already exists.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
|
||||
|
||||
Returns:
|
||||
The created or updated KGRelationship object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
|
||||
"""
|
||||
# Format the relationship_id_name
|
||||
relationship_id_name = format_relationship(relationship_id_name)
|
||||
|
||||
# Check if the relationship already exists
|
||||
existing_relationship = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name == relationship_id_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_relationship:
|
||||
# If it exists, increment the cluster_count
|
||||
existing_relationship.cluster_count = (
|
||||
existing_relationship.cluster_count or 0
|
||||
) + 1
|
||||
db_session.flush()
|
||||
return existing_relationship
|
||||
else:
|
||||
# If it doesn't exist, add it with cluster_count=1
|
||||
return add_relationship(db_session, relationship_id_name, cluster_count=1)
|
||||
|
||||
|
||||
def add_relationship_type(
|
||||
db_session: Session,
|
||||
source_entity_type: str,
|
||||
relationship_type: str,
|
||||
target_entity_type: str,
|
||||
definition: bool = False,
|
||||
extraction_count: int = 0,
|
||||
) -> "KGRelationshipType":
|
||||
"""
|
||||
Add a new relationship type to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
source_entity_type: Type of the source entity
|
||||
relationship_type: Type of relationship
|
||||
target_entity_type: Type of the target entity
|
||||
definition: Whether this relationship type represents a definition (default False)
|
||||
|
||||
Returns:
|
||||
The created KGRelationshipType object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If the relationship type already exists
|
||||
"""
|
||||
|
||||
id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"
|
||||
# Create new relationship type
|
||||
rel_type = KGRelationshipType(
|
||||
id_name=id_name,
|
||||
name=relationship_type,
|
||||
source_entity_type_id_name=source_entity_type.upper(),
|
||||
target_entity_type_id_name=target_entity_type.upper(),
|
||||
definition=definition,
|
||||
cluster_count=extraction_count,
|
||||
type=relationship_type, # Using the relationship_type as the type
|
||||
active=True, # Setting as active by default
|
||||
)
|
||||
|
||||
db_session.add(rel_type)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return rel_type
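
# Illustrative call (not part of the original change set; the entity types are hypothetical):
#
#   add_relationship_type(
#       db_session,
#       source_entity_type="account",
#       relationship_type="uses",
#       target_entity_type="product",
#   )
#
# creates a KGRelationshipType with id_name "ACCOUNT__uses__PRODUCT", name and type
# "uses", active=True, and cluster_count taken from extraction_count (0 by default).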
|
||||
|
||||
|
||||
def get_all_relationship_types(db_session: Session) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Retrieve all relationship types from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects
|
||||
"""
|
||||
return db_session.query(KGRelationshipType).all()
|
||||
|
||||
|
||||
def get_all_relationships(db_session: Session) -> list["KGRelationship"]:
|
||||
"""
|
||||
Retrieve all relationships from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationship objects
|
||||
"""
|
||||
return db_session.query(KGRelationship).all()
|
||||
|
||||
|
||||
def delete_relationships_by_id_names(db_session: Session, id_names: list[str]) -> int:
|
||||
"""
|
||||
Delete relationships from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationships deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def delete_relationship_types_by_id_names(
|
||||
db_session: Session, id_names: list[str]
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationship types from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship type id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationship types deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipType)
|
||||
.filter(KGRelationshipType.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_relationships_for_entity_type_pairs(
|
||||
db_session: Session, entity_type_pairs: list[tuple[str, str]]
|
||||
) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Get relationship types from the database based on a list of entity type pairs.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects where source and target types match the provided pairs
|
||||
"""
|
||||
|
||||
conditions = [
|
||||
(
|
||||
(KGRelationshipType.source_entity_type_id_name == source_type)
|
||||
& (KGRelationshipType.target_entity_type_id_name == target_type)
|
||||
)
|
||||
for source_type, target_type in entity_type_pairs
|
||||
]
|
||||
|
||||
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
|
||||
|
||||
|
||||
def get_allowed_relationship_type_pairs(
|
||||
db_session: Session, entities: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the allowed relationship pairs for the given entities.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entities: List of entity type ID names to filter by
|
||||
|
||||
Returns:
|
||||
List of id_names from KGRelationshipType where both source and target entity types
|
||||
are in the provided entities list
|
||||
"""
|
||||
entity_types = list(set([entity.split(":")[0] for entity in entities]))
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
|
||||
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
|
||||
"""Get all relationship ID names where the given entity is either the source or target node.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_id: ID of the entity to find relationships for
|
||||
|
||||
Returns:
|
||||
List of relationship ID names where the entity is either source or target
|
||||
"""
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationship.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationship.source_node == entity_id,
|
||||
KGRelationship.target_node == entity_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
@@ -4,8 +4,6 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
@@ -98,7 +96,7 @@ class VespaDocumentFields:
|
||||
understandable like this for now.
|
||||
"""
|
||||
|
||||
# all other fields except these 4 and knowledge graph will always be left alone by the update request
|
||||
# all other fields except these 4 will always be left alone by the update request
|
||||
access: DocumentAccess | None = None
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
@@ -353,9 +351,7 @@ class HybridCapable(abc.ABC):
|
||||
hybrid_alpha: float,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
ranking_profile_type: QueryExpansionType,
|
||||
offset: int = 0,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
|
||||
@@ -85,25 +85,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
indexing: attribute
|
||||
}
|
||||
|
||||
# Separate array fields for knowledge graph data
|
||||
field kg_entities type weightedset<string> {
|
||||
rank: filter
|
||||
indexing: summary | attribute
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
field kg_relationships type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
field kg_terms type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
# Needs to have a separate Attribute list for efficient filtering
|
||||
field metadata_list type array<string> {
|
||||
indexing: summary | attribute
|
||||
@@ -195,7 +176,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
match-features: recency_bias
|
||||
}
|
||||
|
||||
rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank {
|
||||
rank-profile hybrid_searchVARIABLE_DIM inherits default, default_rank {
|
||||
inputs {
|
||||
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
|
||||
}
|
||||
@@ -211,75 +192,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
|
||||
# First phase must be vector to allow hits that have no keyword matches
|
||||
first-phase {
|
||||
expression: query(title_content_ratio) * closeness(field, title_embedding) + (1 - query(title_content_ratio)) * closeness(field, embeddings)
|
||||
}
|
||||
|
||||
# Weighted average between Vector Search and BM-25
|
||||
global-phase {
|
||||
expression {
|
||||
(
|
||||
# Weighted Vector Similarity Score
|
||||
(
|
||||
query(alpha) * (
|
||||
(query(title_content_ratio) * normalize_linear(title_vector_score))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
|
||||
)
|
||||
)
|
||||
|
||||
+
|
||||
|
||||
# Weighted Keyword Similarity Score
|
||||
# Note: for the BM25 Title score, it requires decent stopword removal in the query
|
||||
# This needs to be the case so there aren't irrelevant titles being normalized to a score of 1
|
||||
(
|
||||
(1 - query(alpha)) * (
|
||||
(query(title_content_ratio) * normalize_linear(bm25(title)))
|
||||
+
|
||||
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
|
||||
)
|
||||
)
|
||||
)
|
||||
# Boost based on user feedback
|
||||
* document_boost
|
||||
# Decay factor based on time document was last updated
|
||||
* recency_bias
|
||||
# Boost based on aggregated boost calculation
|
||||
* aggregated_chunk_boost
|
||||
}
|
||||
rerank-count: 1000
|
||||
}
|
||||
|
||||
match-features {
|
||||
bm25(title)
|
||||
bm25(content)
|
||||
closeness(field, title_embedding)
|
||||
closeness(field, embeddings)
|
||||
document_boost
|
||||
recency_bias
|
||||
aggregated_chunk_boost
|
||||
closest(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank {
|
||||
inputs {
|
||||
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
|
||||
}
|
||||
|
||||
function title_vector_score() {
|
||||
expression {
|
||||
# If no good matching titles, then it should use the context embeddings rather than having some
|
||||
# irrelevant title have a vector score of 1. This way at least it will be the doc with the highest
|
||||
# matching content score getting the full score
|
||||
max(closeness(field, embeddings), closeness(field, title_embedding))
|
||||
}
|
||||
}
|
||||
|
||||
# First phase must be vector to allow hits that have no keyword matches
|
||||
first-phase {
|
||||
expression: query(title_content_ratio) * bm25(title) + (1 - query(title_content_ratio)) * bm25(content)
|
||||
expression: closeness(field, embeddings)
|
||||
}
|
||||
|
||||
# Weighted average between Vector Search and BM-25
|
||||
|
||||
@@ -166,19 +166,18 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
# build the list of fields to retrieve
|
||||
field_set_list = (
|
||||
[] if not field_names else [f"{field_name}" for field_name in field_names]
|
||||
None
|
||||
if not field_names
|
||||
else [f"{index_name}:{field_name}" for field_name in field_names]
|
||||
)
|
||||
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
|
||||
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
|
||||
if (
|
||||
field_set_list
|
||||
and filters.access_control_list
|
||||
and acl_fieldset_entry not in field_set_list
|
||||
):
|
||||
field_set_list.append(acl_fieldset_entry)
|
||||
if field_set_list:
|
||||
field_set = f"{index_name}:" + ",".join(field_set_list)
|
||||
else:
|
||||
field_set = None
|
||||
field_set = ",".join(field_set_list) if field_set_list else None
|
||||
|
||||
# build filters
|
||||
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
|
||||
|
||||
@@ -17,10 +17,8 @@ from uuid import UUID
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -30,7 +28,6 @@ from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
@@ -102,37 +99,6 @@ class _VespaUpdateRequest:
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGVespaChunkUpdateRequest(BaseModel):
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
url: str
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGUChunkUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: set[str] | None = None
|
||||
relationships: set[str] | None = None
|
||||
terms: set[str] | None = None
|
||||
|
||||
|
||||
class KGUDocumentUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
entities: set[str]
|
||||
relationships: set[str]
|
||||
terms: set[str]
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -537,51 +503,6 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
@classmethod
|
||||
def _apply_kg_chunk_updates_batched(
|
||||
cls,
|
||||
updates: list[KGVespaChunkUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
|
||||
def _kg_update_chunk(
|
||||
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
|
||||
) -> httpx.Response:
|
||||
# logger.debug(
|
||||
# f"Updating KG with request to {update.url} with body {update.update_request}"
|
||||
# )
|
||||
return http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
httpx_client as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
executor.submit(
|
||||
_kg_update_chunk,
|
||||
update,
|
||||
http_client,
|
||||
): update.document_id
|
||||
for update in update_batch
|
||||
}
|
||||
for future in concurrent.futures.as_completed(future_to_document_id):
|
||||
res = future.result()
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
@@ -665,89 +586,6 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def kg_chunk_updates(
|
||||
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
|
||||
) -> None:
|
||||
def _get_general_entity(specific_entity: str) -> str:
|
||||
entity_type, entity_name = specific_entity.split(":")
|
||||
if entity_type != "*":
|
||||
return f"{entity_type}:*"
|
||||
else:
|
||||
return specific_entity
|
||||
|
||||
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
|
||||
logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
|
||||
|
||||
update_start = time.monotonic()
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
|
||||
for kg_update_request in kg_update_requests:
|
||||
kg_update_dict: dict[str, dict] = {"fields": {}}
|
||||
|
||||
implied_entities = set()
|
||||
if kg_update_request.relationships is not None:
|
||||
for kg_relationship in kg_update_request.relationships:
|
||||
kg_relationship_split = kg_relationship.split("__")
|
||||
if len(kg_relationship_split) == 3:
|
||||
implied_entities.add(kg_relationship_split[0])
|
||||
implied_entities.add(kg_relationship_split[2])
|
||||
# Keep this for now in case we want to also add the general entities
|
||||
# implied_entities.add(_get_general_entity(kg_relationship_split[0]))
|
||||
# implied_entities.add(_get_general_entity(kg_relationship_split[2]))
|
||||
|
||||
kg_update_dict["fields"]["kg_relationships"] = {
|
||||
"assign": {
|
||||
kg_relationship: 1
|
||||
for kg_relationship in kg_update_request.relationships
|
||||
}
|
||||
}
|
||||
|
||||
if kg_update_request.entities is not None or implied_entities:
|
||||
if kg_update_request.entities is None:
|
||||
kg_entities = implied_entities
|
||||
else:
|
||||
kg_entities = set(kg_update_request.entities)
|
||||
kg_entities.update(implied_entities)
|
||||
|
||||
kg_update_dict["fields"]["kg_entities"] = {
|
||||
"assign": {kg_entity: 1 for kg_entity in kg_entities}
|
||||
}
|
||||
|
||||
if kg_update_request.terms is not None:
|
||||
kg_update_dict["fields"]["kg_terms"] = {
|
||||
"assign": {kg_term: 1 for kg_term in kg_update_request.terms}
|
||||
}
|
||||
|
||||
if not kg_update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
continue
|
||||
|
||||
doc_chunk_id = get_uuid_from_chunk_info(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
tenant_id=tenant_id,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
processed_updates_requests.append(
|
||||
KGVespaChunkUpdateRequest(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=kg_update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_kg_chunk_updates_batched(
|
||||
processed_updates_requests, httpx_client
|
||||
)
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
)
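
    # Illustrative shape of a single chunk update built above (not part of the original
    # change set; the document, entities, and relationship are hypothetical):
    #
    #   KGVespaChunkUpdateRequest(
    #       document_id="doc-1",
    #       chunk_id=0,
    #       url="<DOCUMENT_ID_ENDPOINT for this index>/<chunk-uuid>",
    #       update_request={
    #           "fields": {
    #               "kg_relationships": {"assign": {"ACCOUNT:Acme__uses__PRODUCT:Onyx": 1}},
    #               "kg_entities": {"assign": {"ACCOUNT:Acme": 1, "PRODUCT:Onyx": 1}},
    #               "kg_terms": {"assign": {"knowledge graph": 1}},
    #           }
    #       },
    #   )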
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
@@ -962,14 +800,12 @@ class VespaIndex(DocumentIndex):
|
||||
hybrid_alpha: float,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
ranking_profile_type: QueryExpansionType,
|
||||
offset: int = 0,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
target_hits = max(10 * num_to_retrieve, 1000)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self.index_name)
|
||||
+ vespa_where_clauses
|
||||
@@ -981,11 +817,6 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
final_query = " ".join(final_keywords) if final_keywords else query
|
||||
|
||||
if ranking_profile_type == QueryExpansionType.KEYWORD:
|
||||
ranking_profile = f"hybrid_search_keyword_base_{len(query_embedding)}"
|
||||
else:
|
||||
ranking_profile = f"hybrid_search_semantic_base_{len(query_embedding)}"
|
||||
|
||||
logger.debug(f"Query YQL: {yql}")
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
@@ -1001,7 +832,7 @@ class VespaIndex(DocumentIndex):
|
||||
),
|
||||
"hits": num_to_retrieve,
|
||||
"offset": offset,
|
||||
"ranking.profile": ranking_profile,
|
||||
"ranking.profile": f"hybrid_search{len(query_embedding)}",
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from backend.onyx.chat.process_message import get_inference_chunks
|
||||
# from backend.onyx.document_index.vespa.index import VespaIndex
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class KGChunkInfo(BaseModel):
|
||||
kg_relationships: dict[str, int]
|
||||
kg_entities: dict[str, int]
|
||||
kg_terms: dict[str, int]
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def get_document_kg_info(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
filters: IndexFilters | None = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieve the kg_info attribute from a Vespa document by its document_id.
|
||||
|
||||
Args:
|
||||
document_id: The unique identifier of the document.
|
||||
index_name: The name of the Vespa index to query.
|
||||
filters: Optional access control filters to apply.
|
||||
|
||||
Returns:
|
||||
        A dictionary mapping chunk id to that chunk's KG info (relationships, entities, and terms).
|
||||
"""
|
||||
# Use the existing visit API infrastructure
|
||||
kg_doc_info: dict[int, KGChunkInfo] = {}
|
||||
|
||||
document_chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=filters or IndexFilters(access_control_list=None),
|
||||
field_names=["kg_relationships", "kg_entities", "kg_terms"],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
for chunk_id, document_chunk in enumerate(document_chunks):
|
||||
kg_chunk_info = KGChunkInfo(
|
||||
kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
|
||||
kg_entities=document_chunk["fields"].get("kg_entities", {}),
|
||||
kg_terms=document_chunk["fields"].get("kg_terms", {}),
|
||||
)
|
||||
|
||||
kg_doc_info[chunk_id] = kg_chunk_info # TODO: check the chunk id is correct!
|
||||
|
||||
return kg_doc_info
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def update_kg_chunks_vespa_info(
|
||||
kg_update_requests: list[KGUChunkUpdateRequest],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
""" """
|
||||
# Use the existing visit API infrastructure
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=False,
|
||||
multitenant=False,
|
||||
httpx_client=None,
|
||||
)
|
||||
|
||||
vespa_index.kg_chunk_updates(
|
||||
kg_update_requests=kg_update_requests, tenant_id=tenant_id
|
||||
)
|
||||
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
@@ -66,29 +67,6 @@ def build_vespa_filters(
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
kg_relationships: list[str] | None,
|
||||
kg_terms: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_entities and not kg_relationships and not kg_terms:
|
||||
return ""
|
||||
|
||||
filter_parts = []
|
||||
|
||||
# Process each filter type using the same pattern
|
||||
for filter_type, values in [
|
||||
("kg_entities", kg_entities),
|
||||
("kg_relationships", kg_relationships),
|
||||
("kg_terms", kg_terms),
|
||||
]:
|
||||
if values:
|
||||
filter_parts.append(
|
||||
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
|
||||
)
|
||||
|
||||
return f"({' and '.join(filter_parts)}) and "
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -128,13 +106,6 @@ def build_vespa_filters(
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# KG filter
|
||||
filter_str += _build_kg_filter(
|
||||
kg_entities=filters.kg_entities,
|
||||
kg_relationships=filters.kg_relationships,
|
||||
kg_terms=filters.kg_terms,
|
||||
)
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
@@ -159,8 +159,6 @@ def load_files_from_zip(
|
||||
zip_metadata = json.load(metadata_file)
|
||||
if isinstance(zip_metadata, list):
|
||||
# convert list of dicts to dict of dicts
|
||||
# Use just the basename for matching since metadata may not include
|
||||
# the full path within the ZIP file
|
||||
zip_metadata = {d["filename"]: d for d in zip_metadata}
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
|
||||
@@ -178,13 +176,7 @@ def load_files_from_zip(
|
||||
continue
|
||||
|
||||
with zip_file.open(file_info.filename, "r") as subfile:
|
||||
# Try to match by exact filename first
|
||||
if file_info.filename in zip_metadata:
|
||||
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
|
||||
else:
|
||||
# Then try matching by just the basename
|
||||
basename = os.path.basename(file_info.filename)
|
||||
yield file_info, subfile, zip_metadata.get(basename, {})
|
||||
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
|
||||
|
||||
|
||||
def _extract_onyx_metadata(line: str) -> dict | None:
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,229 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from thefuzz import process # type: ignore
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_entities_for_types
|
||||
from onyx.db.relationships import get_relationships_for_entity_type_pairs
|
||||
from onyx.kg.models import NormalizedEntities
|
||||
from onyx.kg.models import NormalizedRelationships
|
||||
from onyx.kg.models import NormalizedTerms
|
||||
from onyx.kg.utils.embeddings import encode_string_batch
|
||||
|
||||
|
||||
def _get_existing_normalized_entities(raw_entities: List[str]) -> List[str]:
|
||||
"""
|
||||
Get existing normalized entities from the database.
|
||||
"""
|
||||
|
||||
entity_types = list(set([entity.split(":")[0] for entity in raw_entities]))
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entities = get_entities_for_types(db_session, entity_types)
|
||||
|
||||
return [entity.id_name for entity in entities]
|
||||
|
||||
|
||||
def _get_existing_normalized_relationships(
|
||||
raw_relationships: List[str],
|
||||
) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Get existing normalized relationships from the database.
|
||||
"""
|
||||
|
||||
relationship_type_map: Dict[str, Dict[str, List[str]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
relationship_pairs = list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
relationship.split("__")[0].split(":")[0],
|
||||
relationship.split("__")[2].split(":")[0],
|
||||
)
|
||||
for relationship in raw_relationships
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationships = get_relationships_for_entity_type_pairs(
|
||||
db_session, relationship_pairs
|
||||
)
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_type_map[relationship.source_entity_type_id_name][
|
||||
relationship.target_entity_type_id_name
|
||||
].append(relationship.id_name)
|
||||
|
||||
return relationship_type_map
|
||||
|
||||
|
||||
def normalize_entities(raw_entities: List[str]) -> NormalizedEntities:
|
||||
"""
|
||||
Match each entity against a list of normalized entities using fuzzy matching.
|
||||
Returns the best matching normalized entity for each input entity.
|
||||
|
||||
    Args:
        raw_entities: List of entity strings to normalize

    Returns:
        NormalizedEntities containing the normalized entities and the entity normalization map
    """
|
||||
# Assume this is your predefined list of normalized entities
|
||||
norm_entities = _get_existing_normalized_entities(raw_entities)
|
||||
|
||||
normalized_results: List[str] = []
|
||||
normalized_map: Dict[str, str | None] = {}
|
||||
threshold = 80 # Adjust threshold as needed
|
||||
|
||||
for entity in raw_entities:
|
||||
# Find the best match and its score from norm_entities
|
||||
best_match, score = process.extractOne(entity, norm_entities)
|
||||
|
||||
if score >= threshold:
|
||||
normalized_results.append(best_match)
|
||||
normalized_map[entity] = best_match
|
||||
else:
|
||||
# If no good match found, keep original
|
||||
normalized_map[entity] = None
|
||||
|
||||
return NormalizedEntities(
|
||||
entities=normalized_results, entity_normalization_map=normalized_map
|
||||
)
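
# Illustrative sketch (not part of the original change set; the entity values are
# hypothetical): with the fuzzy-match threshold of 80, a raw entity close to an existing
# normalized entity maps onto it, otherwise it maps to None and is left out of the list.
#
#   normalize_entities(["ACCOUNT:Acme Corporation", "ACCOUNT:Unknown Startup"])
#   -> NormalizedEntities(
#          entities=["ACCOUNT:Acme Corp"],
#          entity_normalization_map={
#              "ACCOUNT:Acme Corporation": "ACCOUNT:Acme Corp",
#              "ACCOUNT:Unknown Startup": None,
#          },
#      )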
|
||||
|
||||
|
||||
def normalize_relationships(
|
||||
raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
|
||||
) -> NormalizedRelationships:
|
||||
"""
|
||||
Normalize relationships using entity mappings and relationship string matching.
|
||||
|
||||
Args:
|
||||
        raw_relationships: List of relationships in the format "source__relation__target"
|
||||
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
|
||||
|
||||
Returns:
|
||||
NormalizedRelationships containing normalized relationships and mapping
|
||||
"""
|
||||
# Placeholder for normalized relationship structure
|
||||
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
|
||||
|
||||
normalized_rels: List[str] = []
|
||||
normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
for raw_rel in raw_relationships:
|
||||
# 1. Split and normalize entities
|
||||
try:
|
||||
source, rel_string, target = raw_rel.split("__")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid relationship format: {raw_rel}")
|
||||
|
||||
# Check if entities are in normalization map and not None
|
||||
norm_source = entity_normalization_map.get(source)
|
||||
norm_target = entity_normalization_map.get(target)
|
||||
|
||||
if norm_source is None or norm_target is None:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 2. Find candidate normalized relationships
|
||||
candidate_rels = []
|
||||
norm_source_type = norm_source.split(":")[0]
|
||||
norm_target_type = norm_target.split(":")[0]
|
||||
if (
|
||||
norm_source_type in nor_relationships
|
||||
and norm_target_type in nor_relationships[norm_source_type]
|
||||
):
|
||||
candidate_rels = [
|
||||
rel.split("__")[1]
|
||||
for rel in nor_relationships[norm_source_type][norm_target_type]
|
||||
]
|
||||
|
||||
if not candidate_rels:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 3. Encode and find best match
|
||||
strings_to_encode = [rel_string] + candidate_rels
|
||||
vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# Get raw relation vector and candidate vectors
|
||||
raw_vector = vectors[0]
|
||||
candidate_vectors = vectors[1:]
|
||||
|
||||
# Calculate dot products
|
||||
dot_products = np.dot(candidate_vectors, raw_vector)
|
||||
best_match_idx = np.argmax(dot_products)
|
||||
|
||||
# Create normalized relationship
|
||||
norm_rel = f"{norm_source}__{candidate_rels[best_match_idx]}__{norm_target}"
|
||||
normalized_rels.append(norm_rel)
|
||||
normalization_map[raw_rel] = norm_rel
|
||||
|
||||
return NormalizedRelationships(
|
||||
relationships=normalized_rels, relationship_normalization_map=normalization_map
|
||||
)
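

# Illustrative sketch of the expected input/output format (hypothetical values):
#
#   normalize_relationships(
#       ["ACCOUNT:acme__signed contract with__VENDOR:onyx"],
#       {"ACCOUNT:acme": "ACCOUNT:Acme", "VENDOR:onyx": "VENDOR:Onyx"},
#   )
#
# would return something like
#   relationships=["ACCOUNT:Acme__signed_contract_with__VENDOR:Onyx"]
# where the middle part is the best-matching existing relationship string for
# the (ACCOUNT, VENDOR) entity-type pair, selected via embedding dot products.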
|
||||
|
||||
|
||||
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
|
||||
"""
|
||||
Normalize terms using semantic similarity matching.
|
||||
|
||||
Args:
|
||||
terms: List of terms to normalize
|
||||
|
||||
Returns:
|
||||
NormalizedTerms containing normalized terms and mapping
|
||||
"""
|
||||
# # Placeholder for normalized terms - this would typically come from a predefined list
|
||||
# normalized_term_list = [
|
||||
# "algorithm",
|
||||
# "database",
|
||||
# "software",
|
||||
# "programming",
|
||||
# # ... other normalized terms ...
|
||||
# ]
|
||||
|
||||
# normalized_terms: List[str] = []
|
||||
# normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
# if not raw_terms:
|
||||
# return NormalizedTerms(terms=[], term_normalization_map={})
|
||||
|
||||
# # Encode all terms at once for efficiency
|
||||
# strings_to_encode = raw_terms + normalized_term_list
|
||||
# vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# # Split vectors into query terms and candidate terms
|
||||
# query_vectors = vectors[:len(raw_terms)]
|
||||
# candidate_vectors = vectors[len(raw_terms):]
|
||||
|
||||
# # Calculate similarity for each term
|
||||
# for i, term in enumerate(raw_terms):
|
||||
# # Calculate dot products with all candidates
|
||||
# similarities = np.dot(candidate_vectors, query_vectors[i])
|
||||
# best_match_idx = np.argmax(similarities)
|
||||
# best_match_score = similarities[best_match_idx]
|
||||
|
||||
# # Use a threshold to determine if the match is good enough
|
||||
# if best_match_score > 0.7: # Adjust threshold as needed
|
||||
# normalized_term = normalized_term_list[best_match_idx]
|
||||
# normalized_terms.append(normalized_term)
|
||||
# normalization_map[term] = normalized_term
|
||||
# else:
|
||||
# # If no good match found, keep original
|
||||
# normalization_map[term] = None
|
||||
|
||||
# return NormalizedTerms(
|
||||
# terms=normalized_terms,
|
||||
# term_normalization_map=normalization_map
|
||||
# )
|
||||
|
||||
return NormalizedTerms(
|
||||
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
|
||||
)
|
||||
@@ -1,191 +0,0 @@
|
||||
from onyx.configs.kg_configs import KG_IGNORE_EMAIL_DOMAINS
|
||||
from onyx.configs.kg_configs import KG_OWN_COMPANY
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_kg_entity_by_document
|
||||
from onyx.kg.context_preparations_extraction.models import ContextPreparation
|
||||
from onyx.kg.context_preparations_extraction.models import (
|
||||
KGDocumentClassificationPrompt,
|
||||
)
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import kg_email_processing
|
||||
from onyx.prompts.kg_prompts import FIREFLIES_CHUNK_PREPROCESSING_PROMPT
|
||||
from onyx.prompts.kg_prompts import FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT
|
||||
|
||||
|
||||
def prepare_llm_content_fireflies(chunk: KGChunkFormat) -> ContextPreparation:
|
||||
"""
|
||||
Fireflies - prepare the content for the LLM.
|
||||
"""
|
||||
|
||||
document_id = chunk.document_id
|
||||
primary_owners = chunk.primary_owners
|
||||
secondary_owners = chunk.secondary_owners
|
||||
content = chunk.content
|
||||
    # The chunk title is currently not used in the Fireflies prompt below
|
||||
|
||||
implied_entities = set()
|
||||
implied_relationships = set()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
core_document = get_kg_entity_by_document(db_session, document_id)
|
||||
|
||||
if core_document:
|
||||
core_document_id_name = core_document.id_name
|
||||
else:
|
||||
core_document_id_name = f"FIREFLIES:{document_id}"
|
||||
|
||||
# Do we need this here?
|
||||
implied_entities.add(f"VENDOR:{KG_OWN_COMPANY}")
|
||||
implied_entities.add(f"{core_document_id_name}")
|
||||
implied_entities.add("FIREFLIES:*")
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{KG_OWN_COMPANY}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
company_participant_emails = set()
|
||||
account_participant_emails = set()
|
||||
|
||||
for owner in primary_owners + secondary_owners:
|
||||
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
kg_owner = kg_email_processing(owner)
|
||||
if any(
|
||||
domain.lower() in kg_owner.company.lower()
|
||||
for domain in KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
continue
|
||||
|
||||
if kg_owner.employee:
|
||||
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
if kg_owner.name not in implied_entities:
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_entities.add(f"EMPLOYEE:{kg_owner.name}")
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:{kg_owner.name}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:*__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"EMPLOYEE:*__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
if kg_owner.company not in implied_entities:
|
||||
implied_entities.add(f"VENDOR:{kg_owner.company}")
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"VENDOR:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
else:
|
||||
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
if kg_owner.company not in implied_entities:
|
||||
implied_entities.add(f"ACCOUNT:{kg_owner.company}")
|
||||
implied_entities.add("ACCOUNT:*")
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:*__relates_neutrally_to__{core_document_id_name}"
|
||||
)
|
||||
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:*__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
implied_relationships.add(
|
||||
f"ACCOUNT:{kg_owner.company}__relates_neutrally_to__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
participant_string = "\n - " + "\n - ".join(company_participant_emails)
|
||||
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
|
||||
|
||||
llm_context = FIREFLIES_CHUNK_PREPROCESSING_PROMPT.format(
|
||||
participant_string=participant_string,
|
||||
account_participant_string=account_participant_string,
|
||||
content=content,
|
||||
)
|
||||
|
||||
return ContextPreparation(
|
||||
llm_context=llm_context,
|
||||
core_entity=core_document_id_name,
|
||||
implied_entities=list(implied_entities),
|
||||
implied_relationships=list(implied_relationships),
|
||||
implied_terms=[],
|
||||
)
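

# Illustrative sketch of the implied graph items this function produces for a
# Fireflies call (hypothetical document id and participants):
#
#   implied_entities       e.g. {"VENDOR:Onyx", "FIREFLIES:*", "EMPLOYEE:Jane",
#                                "FIREFLIES:doc-123"}
#   implied_relationships  e.g. {"EMPLOYEE:Jane__relates_neutrally_to__FIREFLIES:doc-123",
#                                "ACCOUNT:Acme__relates_neutrally_to__FIREFLIES:*"}
#
# Entities are "TYPE:name" strings and relationships are
# "source__relationship__target" strings, matching the formats used elsewhere
# in the KG extraction code.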
|
||||
|
||||
|
||||
def prepare_llm_document_content_fireflies(
|
||||
document_classification_content: KGClassificationContent,
|
||||
category_list: str,
|
||||
category_definition_string: str,
|
||||
) -> KGDocumentClassificationPrompt:
|
||||
"""
|
||||
Fireflies - prepare prompt for the LLM classification.
|
||||
"""
|
||||
|
||||
prompt = FIREFLIES_DOCUMENT_CLASSIFICATION_PROMPT.format(
|
||||
beginning_of_call_content=document_classification_content.classification_content,
|
||||
category_list=category_list,
|
||||
category_options=category_definition_string,
|
||||
)
|
||||
|
||||
return KGDocumentClassificationPrompt(
|
||||
llm_prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def get_classification_content_from_fireflies_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Creates a KGClassificationContent object from a list of Fireflies chunks.
|
||||
"""
|
||||
|
||||
assert isinstance(KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
primary_owners = first_num_classification_chunks[0]["fields"]["primary_owners"]
|
||||
secondary_owners = first_num_classification_chunks[0]["fields"]["secondary_owners"]
|
||||
|
||||
company_participant_emails = set()
|
||||
account_participant_emails = set()
|
||||
|
||||
for owner in primary_owners + secondary_owners:
|
||||
kg_owner = kg_email_processing(owner)
|
||||
if any(
|
||||
domain.lower() in kg_owner.company.lower()
|
||||
for domain in KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
continue
|
||||
|
||||
if kg_owner.employee:
|
||||
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
else:
|
||||
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
|
||||
participant_string = "\n - " + "\n - ".join(company_participant_emails)
|
||||
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
|
||||
|
||||
title_string = first_num_classification_chunks[0]["fields"]["title"]
|
||||
content_string = "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
|
||||
classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
|
||||
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
|
||||
|
||||
return classification_content
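

# The chunks passed in are expected to look roughly like the following
# (a sketch with hypothetical values; only the fields accessed above are shown):
#
#   first_num_classification_chunks = [
#       {
#           "fields": {
#               "title": "Acme <> Onyx - Weekly Sync",
#               "primary_owners": ["jane@onyx.app"],
#               "secondary_owners": ["bob@acme.com"],
#               "content": "Transcript text ...",
#           }
#       },
#       ...
#   ]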
|
||||
@@ -1,21 +0,0 @@
from pydantic import BaseModel


class ContextPreparation(BaseModel):
    """
    Context preparation format for the LLM KG extraction.
    """

    llm_context: str
    core_entity: str
    implied_entities: list[str]
    implied_relationships: list[str]
    implied_terms: list[str]


class KGDocumentClassificationPrompt(BaseModel):
    """
    Document classification prompt format for the LLM KG extraction.
    """

    llm_prompt: str | None
@@ -1,947 +0,0 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from onyx.db.connector import get_unprocessed_connector_ids
|
||||
from onyx.db.document import get_document_updated_at
|
||||
from onyx.db.document import get_unprocessed_kg_documents_for_connector
|
||||
from onyx.db.document import update_document_kg_info
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import add_entity
|
||||
from onyx.db.entities import get_entity_types
|
||||
from onyx.db.relationships import add_or_increment_relationship
|
||||
from onyx.db.relationships import add_relationship
|
||||
from onyx.db.relationships import add_relationship_type
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import KGUDocumentUpdateRequest
|
||||
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
|
||||
from onyx.kg.models import ConnectorExtractionStats
|
||||
from onyx.kg.models import KGAggregatedExtractions
|
||||
from onyx.kg.models import KGBatchExtractionStats
|
||||
from onyx.kg.models import KGChunkExtraction
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGChunkId
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.models import KGClassificationDecisions
|
||||
from onyx.kg.models import KGClassificationInstructionStrings
|
||||
from onyx.kg.utils.chunk_preprocessing import prepare_llm_content
|
||||
from onyx.kg.utils.chunk_preprocessing import prepare_llm_document_content
|
||||
from onyx.kg.utils.formatting_utils import aggregate_kg_extractions
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import generalize_relationships
|
||||
from onyx.kg.vespa.vespa_interactions import get_document_chunks_for_kg_processing
|
||||
from onyx.kg.vespa.vespa_interactions import (
|
||||
get_document_classification_content_for_kg_processing,
|
||||
)
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.kg_prompts import MASTER_EXTRACTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_classification_instructions() -> Dict[str, KGClassificationInstructionStrings]:
|
||||
"""
|
||||
Prepare the classification instructions for the given source.
|
||||
"""
|
||||
|
||||
classification_instructions_dict: Dict[str, KGClassificationInstructionStrings] = {}
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_types = get_entity_types(db_session, active=None)
|
||||
|
||||
for entity_type in entity_types:
|
||||
grounded_source_name = entity_type.grounded_source_name
|
||||
if grounded_source_name is None:
|
||||
continue
|
||||
classification_class_definitions = entity_type.classification_requirements
|
||||
|
||||
classification_options = ", ".join(classification_class_definitions.keys())
|
||||
|
||||
classification_instructions_dict[
|
||||
grounded_source_name
|
||||
] = KGClassificationInstructionStrings(
|
||||
classification_options=classification_options,
|
||||
classification_class_definitions=classification_class_definitions,
|
||||
)
|
||||
|
||||
return classification_instructions_dict
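

# The returned mapping is keyed by grounded source name, for example
# (hypothetical entity-type configuration):
#
#   {
#       "fireflies": KGClassificationInstructionStrings(
#           classification_options="customer_call, internal_sync",
#           classification_class_definitions={
#               "customer_call": {"description": "...", "extraction": True},
#               "internal_sync": {"description": "...", "extraction": False},
#           },
#       ),
#   }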
|
||||
|
||||
|
||||
def get_entity_types_str(active: bool | None = None) -> str:
|
||||
"""
|
||||
Get the entity types from the KGChunkExtraction model.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_entity_types = get_entity_types(db_session, active)
|
||||
|
||||
entity_types_list = []
|
||||
for entity_type in active_entity_types:
|
||||
if entity_type.description:
|
||||
entity_description = "\n - Description: " + entity_type.description
|
||||
else:
|
||||
entity_description = ""
|
||||
if entity_type.ge_determine_instructions:
|
||||
allowed_options = "\n - Allowed Options: " + ", ".join(
|
||||
entity_type.ge_determine_instructions
|
||||
)
|
||||
else:
|
||||
allowed_options = ""
|
||||
entity_types_list.append(
|
||||
entity_type.id_name + entity_description + allowed_options
|
||||
)
|
||||
|
||||
return "\n".join(entity_types_list)
|
||||
|
||||
|
||||
def get_relationship_types_str(active: bool | None = None) -> str:
|
||||
"""
|
||||
Get the relationship types from the database.
|
||||
|
||||
Args:
|
||||
active: Filter by active status (True, False, or None for all)
|
||||
|
||||
Returns:
|
||||
A string with all relationship types formatted as "source_type__relationship_type__target_type"
|
||||
"""
|
||||
from onyx.db.relationships import get_all_relationship_types
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationship_types = get_all_relationship_types(db_session)
|
||||
|
||||
# Filter by active status if specified
|
||||
if active is not None:
|
||||
relationship_types = [
|
||||
rt for rt in relationship_types if rt.active == active
|
||||
]
|
||||
|
||||
relationship_types_list = []
|
||||
for rel_type in relationship_types:
|
||||
# Format as "source_type__relationship_type__target_type"
|
||||
formatted_type = f"{rel_type.source_entity_type_id_name}__{rel_type.type}__{rel_type.target_entity_type_id_name}"
|
||||
relationship_types_list.append(formatted_type)
|
||||
|
||||
return "\n".join(relationship_types_list)
|
||||
|
||||
|
||||
def kg_extraction_initialization(tenant_id: str, num_chunks: int = 1000) -> None:
|
||||
"""
|
||||
This extraction will create a random sample of chunks to process in order to perform
|
||||
clustering and topic modeling.
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg extraction for tenant {tenant_id}")
|
||||
|
||||
|
||||
def kg_extraction(
|
||||
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
|
||||
) -> list[ConnectorExtractionStats]:
|
||||
"""
|
||||
This extraction will try to extract from all chunks that have not been kg-processed yet.
|
||||
|
||||
Approach:
|
||||
- Get all unprocessed connectors
|
||||
- For each connector:
|
||||
- Get all unprocessed documents
|
||||
- Classify each document to select proper ones
|
||||
- For each document:
|
||||
- Get all chunks
|
||||
- For each chunk:
|
||||
- Extract entities, relationships, and terms
|
||||
- make sure for each entity and relationship also the generalized versions are extracted!
|
||||
- Aggregate results as needed
|
||||
- Update Vespa and postgres
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg extraction for tenant {tenant_id}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
connector_ids = get_unprocessed_connector_ids(db_session)
|
||||
|
||||
connector_extraction_stats: list[ConnectorExtractionStats] = []
|
||||
document_kg_updates: Dict[str, KGUDocumentUpdateRequest] = {}
|
||||
|
||||
processing_chunks: list[KGChunkFormat] = []
|
||||
carryover_chunks: list[KGChunkFormat] = []
|
||||
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions] = []
|
||||
|
||||
document_classification_instructions = _get_classification_instructions()
|
||||
|
||||
for connector_id in connector_ids:
|
||||
connector_failed_chunk_extractions: list[KGChunkId] = []
|
||||
connector_succeeded_chunk_extractions: list[KGChunkId] = []
|
||||
connector_aggregated_kg_extractions: KGAggregatedExtractions = (
|
||||
KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
unprocessed_documents = get_unprocessed_kg_documents_for_connector(
|
||||
db_session,
|
||||
connector_id,
|
||||
)
|
||||
|
||||
# TODO: restricted for testing only
|
||||
unprocessed_documents_list = list(unprocessed_documents)[:6]
|
||||
|
||||
document_classification_content_list = (
|
||||
get_document_classification_content_for_kg_processing(
|
||||
[
|
||||
unprocessed_document.id
|
||||
for unprocessed_document in unprocessed_documents_list
|
||||
],
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
)
|
||||
|
||||
classification_outcomes: list[tuple[bool, KGClassificationDecisions]] = []
|
||||
|
||||
for document_classification_content in document_classification_content_list:
|
||||
classification_outcomes.extend(
|
||||
_kg_document_classification(
|
||||
document_classification_content,
|
||||
document_classification_instructions,
|
||||
)
|
||||
)
|
||||
|
||||
documents_to_process = []
|
||||
for document_to_process, document_classification_outcome in zip(
|
||||
unprocessed_documents_list, classification_outcomes
|
||||
):
|
||||
if (
|
||||
document_classification_outcome[0]
|
||||
and document_classification_outcome[1].classification_decision
|
||||
):
|
||||
documents_to_process.append(document_to_process)
|
||||
|
||||
for document_to_process in documents_to_process:
|
||||
formatted_chunk_batches = get_document_chunks_for_kg_processing(
|
||||
document_to_process.id,
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
|
||||
formatted_chunk_batches_list = list(formatted_chunk_batches)
|
||||
|
||||
for formatted_chunk_batch in formatted_chunk_batches_list:
|
||||
processing_chunks.extend(formatted_chunk_batch)
|
||||
|
||||
if len(processing_chunks) >= processing_chunk_batch_size:
|
||||
carryover_chunks.extend(
|
||||
processing_chunks[processing_chunk_batch_size:]
|
||||
)
|
||||
processing_chunks = processing_chunks[:processing_chunk_batch_size]
|
||||
|
||||
chunk_processing_batch_results = _kg_chunk_batch_extraction(
|
||||
processing_chunks, index_name, tenant_id
|
||||
)
|
||||
|
||||
# Consider removing the stats expressions here and rather write to the db(?)
|
||||
connector_failed_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.failed
|
||||
)
|
||||
connector_succeeded_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.succeeded
|
||||
)
|
||||
|
||||
aggregated_batch_extractions = (
|
||||
chunk_processing_batch_results.aggregated_kg_extractions
|
||||
)
|
||||
# Update grounded_entities_document_ids (replace values)
|
||||
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
|
||||
aggregated_batch_extractions.grounded_entities_document_ids
|
||||
)
|
||||
# Add to entity counts instead of replacing
|
||||
for entity, count in aggregated_batch_extractions.entities.items():
|
||||
connector_aggregated_kg_extractions.entities[entity] += count
|
||||
# Add to relationship counts instead of replacing
|
||||
for (
|
||||
relationship,
|
||||
count,
|
||||
) in aggregated_batch_extractions.relationships.items():
|
||||
connector_aggregated_kg_extractions.relationships[
|
||||
relationship
|
||||
] += count
|
||||
# Add to term counts instead of replacing
|
||||
for term, count in aggregated_batch_extractions.terms.items():
|
||||
connector_aggregated_kg_extractions.terms[term] += count
|
||||
|
||||
connector_extraction_stats.append(
|
||||
ConnectorExtractionStats(
|
||||
connector_id=connector_id,
|
||||
num_failed=len(connector_failed_chunk_extractions),
|
||||
num_succeeded=len(connector_succeeded_chunk_extractions),
|
||||
num_processed=len(processing_chunks),
|
||||
)
|
||||
)
|
||||
|
||||
processing_chunks = carryover_chunks.copy()
|
||||
carryover_chunks = []
|
||||
|
||||
# processes remaining chunks
|
||||
chunk_processing_batch_results = _kg_chunk_batch_extraction(
|
||||
processing_chunks, index_name, tenant_id
|
||||
)
|
||||
|
||||
# Consider removing the stats expressions here and rather write to the db(?)
|
||||
connector_failed_chunk_extractions.extend(chunk_processing_batch_results.failed)
|
||||
connector_succeeded_chunk_extractions.extend(
|
||||
chunk_processing_batch_results.succeeded
|
||||
)
|
||||
|
||||
aggregated_batch_extractions = (
|
||||
chunk_processing_batch_results.aggregated_kg_extractions
|
||||
)
|
||||
# Update grounded_entities_document_ids (replace values)
|
||||
connector_aggregated_kg_extractions.grounded_entities_document_ids.update(
|
||||
aggregated_batch_extractions.grounded_entities_document_ids
|
||||
)
|
||||
# Add to entity counts instead of replacing
|
||||
for entity, count in aggregated_batch_extractions.entities.items():
|
||||
connector_aggregated_kg_extractions.entities[entity] += count
|
||||
# Add to relationship counts instead of replacing
|
||||
for relationship, count in aggregated_batch_extractions.relationships.items():
|
||||
connector_aggregated_kg_extractions.relationships[relationship] += count
|
||||
# Add to term counts instead of replacing
|
||||
for term, count in aggregated_batch_extractions.terms.items():
|
||||
connector_aggregated_kg_extractions.terms[term] += count
|
||||
|
||||
connector_extraction_stats.append(
|
||||
ConnectorExtractionStats(
|
||||
connector_id=connector_id,
|
||||
num_failed=len(connector_failed_chunk_extractions),
|
||||
num_succeeded=len(connector_succeeded_chunk_extractions),
|
||||
num_processed=len(processing_chunks),
|
||||
)
|
||||
)
|
||||
|
||||
processing_chunks = []
|
||||
carryover_chunks = []
|
||||
|
||||
connector_aggregated_kg_extractions_list.append(
|
||||
connector_aggregated_kg_extractions
|
||||
)
|
||||
|
||||
# aggregate document updates
|
||||
|
||||
for (
|
||||
processed_document
|
||||
) in (
|
||||
unprocessed_documents_list
|
||||
): # This will need to change if we do not materialize docs
|
||||
document_kg_updates[processed_document.id] = KGUDocumentUpdateRequest(
|
||||
document_id=processed_document.id,
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
updated_chunk_batches = get_document_chunks_for_kg_processing(
|
||||
processed_document.id,
|
||||
index_name,
|
||||
batch_size=processing_chunk_batch_size,
|
||||
)
|
||||
|
||||
for updated_chunk_batch in updated_chunk_batches:
|
||||
for updated_chunk in updated_chunk_batch:
|
||||
chunk_entities = updated_chunk.entities
|
||||
chunk_relationships = updated_chunk.relationships
|
||||
chunk_terms = updated_chunk.terms
|
||||
document_kg_updates[processed_document.id].entities.update(
|
||||
chunk_entities
|
||||
)
|
||||
document_kg_updates[processed_document.id].relationships.update(
|
||||
chunk_relationships
|
||||
)
|
||||
document_kg_updates[processed_document.id].terms.update(chunk_terms)
|
||||
|
||||
aggregated_kg_extractions = aggregate_kg_extractions(
|
||||
connector_aggregated_kg_extractions_list
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tracked_entity_types = [
|
||||
x.id_name for x in get_entity_types(db_session, active=None)
|
||||
]
|
||||
|
||||
# Populate the KG database with the extracted entities, relationships, and terms
|
||||
for (
|
||||
entity,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.entities.items():
|
||||
if len(entity.split(":")) != 2:
|
||||
logger.error(
|
||||
f"Invalid entity {entity} in aggregated_kg_extractions.entities"
|
||||
)
|
||||
continue
|
||||
|
||||
entity_type, entity_name = entity.split(":")
|
||||
entity_type = entity_type.upper()
|
||||
entity_name = entity_name.capitalize()
|
||||
|
||||
if entity_type not in tracked_entity_types:
|
||||
continue
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if (
|
||||
entity
|
||||
not in aggregated_kg_extractions.grounded_entities_document_ids
|
||||
):
|
||||
add_entity(
|
||||
db_session=db_session,
|
||||
entity_type=entity_type,
|
||||
name=entity_name,
|
||||
cluster_count=extraction_count,
|
||||
)
|
||||
else:
|
||||
event_time = get_document_updated_at(
|
||||
entity,
|
||||
db_session,
|
||||
)
|
||||
add_entity(
|
||||
db_session=db_session,
|
||||
entity_type=entity_type,
|
||||
name=entity_name,
|
||||
cluster_count=extraction_count,
|
||||
document_id=aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
entity
|
||||
],
|
||||
event_time=event_time,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding entity {entity} to the database: {e}")
|
||||
|
||||
relationship_type_counter: dict[str, int] = defaultdict(int)
|
||||
|
||||
for (
|
||||
relationship,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.relationships.items():
|
||||
relationship_split = relationship.split("__")
|
||||
|
||||
source_entity, relationship_type_, target_entity = relationship.split("__")
|
||||
source_entity = relationship_split[0]
|
||||
relationship_type = " ".join(relationship_split[1:-1]).replace("__", "_")
|
||||
target_entity = relationship_split[-1]
|
||||
|
||||
source_entity_type = source_entity.split(":")[0]
|
||||
target_entity_type = target_entity.split(":")[0]
|
||||
|
||||
if (
|
||||
source_entity_type not in tracked_entity_types
|
||||
or target_entity_type not in tracked_entity_types
|
||||
):
|
||||
continue
|
||||
|
||||
source_entity_general = f"{source_entity_type.upper()}"
|
||||
target_entity_general = f"{target_entity_type.upper()}"
|
||||
relationship_type_id_name = (
|
||||
f"{source_entity_general}__{relationship_type.lower()}__"
|
||||
f"{target_entity_general}"
|
||||
)
|
||||
|
||||
relationship_type_counter[relationship_type_id_name] += extraction_count
|
||||
|
||||
for (
|
||||
relationship_type_id_name,
|
||||
extraction_count,
|
||||
) in relationship_type_counter.items():
|
||||
(
|
||||
source_entity_type,
|
||||
relationship_type,
|
||||
target_entity_type,
|
||||
) = relationship_type_id_name.split("__")
|
||||
|
||||
if (
|
||||
source_entity_type not in tracked_entity_types
|
||||
or target_entity_type not in tracked_entity_types
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
add_relationship_type(
|
||||
db_session=db_session,
|
||||
source_entity_type=source_entity_type.upper(),
|
||||
relationship_type=relationship_type,
|
||||
target_entity_type=target_entity_type.upper(),
|
||||
definition=False,
|
||||
extraction_count=extraction_count,
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship type {relationship_type_id_name} to the database: {e}"
|
||||
)
|
||||
|
||||
for (
|
||||
relationship,
|
||||
extraction_count,
|
||||
) in aggregated_kg_extractions.relationships.items():
|
||||
relationship_split = relationship.split("__")
|
||||
|
||||
source_entity, relationship_type_, target_entity = relationship.split("__")
|
||||
source_entity = relationship_split[0]
|
||||
relationship_type = (
|
||||
" ".join(relationship_split[1:-1]).replace("__", " ").replace("_", " ")
|
||||
)
|
||||
target_entity = relationship_split[-1]
|
||||
|
||||
source_entity_type = source_entity.split(":")[0]
|
||||
target_entity_type = target_entity.split(":")[0]
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
add_relationship(db_session, relationship, extraction_count)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error adding relationship {relationship} to the database: {e}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
add_or_increment_relationship(db_session, relationship)
|
||||
db_session.commit()
|
||||
|
||||
# Populate the Documents table with the kg information for the documents
|
||||
|
||||
for document_id, document_kg_update in document_kg_updates.items():
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_info(
|
||||
db_session,
|
||||
document_id,
|
||||
kg_processed=True,
|
||||
kg_data={
|
||||
"entities": list(document_kg_update.entities),
|
||||
"relationships": list(document_kg_update.relationships),
|
||||
"terms": list(document_kg_update.terms),
|
||||
},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return connector_extraction_stats
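

# Minimal usage sketch (hypothetical tenant and index names); the function
# walks all unprocessed connectors and returns per-connector stats:
#
#   stats = kg_extraction(
#       tenant_id="tenant-123",
#       index_name="my_index",
#       processing_chunk_batch_size=8,
#   )
#   for s in stats:
#       logger.info(f"connector {s.connector_id}: {s.num_succeeded} ok, {s.num_failed} failed")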
|
||||
|
||||
|
||||
def _kg_chunk_batch_extraction(
|
||||
chunks: list[KGChunkFormat],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> KGBatchExtractionStats:
|
||||
_, fast_llm = get_default_llms()
|
||||
|
||||
succeeded_chunk_id: list[KGChunkId] = []
|
||||
failed_chunk_id: list[KGChunkId] = []
|
||||
succeeded_chunk_extraction: list[KGChunkExtraction] = []
|
||||
|
||||
preformatted_prompt = MASTER_EXTRACTION_PROMPT.format(
|
||||
entity_types=get_entity_types_str(active=True)
|
||||
)
|
||||
|
||||
def process_single_chunk(
|
||||
chunk: KGChunkFormat, preformatted_prompt: str
|
||||
) -> tuple[bool, KGUChunkUpdateRequest]:
|
||||
"""Process a single chunk and return success status and chunk ID."""
|
||||
|
||||
# For now, we're just processing the content
|
||||
# TODO: Implement actual prompt application logic
|
||||
|
||||
llm_preprocessing = prepare_llm_content(chunk)
|
||||
|
||||
formatted_prompt = preformatted_prompt.replace(
|
||||
"---content---", llm_preprocessing.llm_context
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=formatted_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"LLM Extraction from chunk {chunk.chunk_id} from doc {chunk.document_id}"
|
||||
)
|
||||
raw_extraction_result = fast_llm.invoke(msg)
|
||||
extraction_result = message_to_string(raw_extraction_result)
|
||||
|
||||
try:
|
||||
cleaned_result = (
|
||||
extraction_result.replace("```json", "").replace("```", "").strip()
|
||||
)
|
||||
parsed_result = json.loads(cleaned_result)
|
||||
extracted_entities = parsed_result.get("entities", [])
|
||||
extracted_relationships = parsed_result.get("relationships", [])
|
||||
extracted_terms = parsed_result.get("terms", [])
|
||||
|
||||
implied_extracted_relationships = [
|
||||
llm_preprocessing.core_entity
|
||||
+ "__"
|
||||
+ "relates_neutrally_to"
|
||||
+ "__"
|
||||
+ entity
|
||||
for entity in extracted_entities
|
||||
]
|
||||
|
||||
all_entities = set(
|
||||
list(extracted_entities)
|
||||
+ list(llm_preprocessing.implied_entities)
|
||||
+ list(
|
||||
generalize_entities(
|
||||
extracted_entities + llm_preprocessing.implied_entities
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"All entities: {all_entities}")
|
||||
|
||||
all_relationships = (
|
||||
extracted_relationships
|
||||
+ llm_preprocessing.implied_relationships
|
||||
+ implied_extracted_relationships
|
||||
)
|
||||
all_relationships = set(
|
||||
list(all_relationships)
|
||||
+ list(generalize_relationships(all_relationships))
|
||||
)
|
||||
|
||||
kg_updates = [
|
||||
KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity=llm_preprocessing.core_entity,
|
||||
entities=all_entities,
|
||||
relationships=all_relationships,
|
||||
terms=set(extracted_terms),
|
||||
),
|
||||
]
|
||||
|
||||
update_kg_chunks_vespa_info(
|
||||
kg_update_requests=kg_updates,
|
||||
index_name=index_name,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"KG updated: {chunk.chunk_id} from doc {chunk.document_id}"
|
||||
)
|
||||
|
||||
return True, kg_updates[0] # only single chunk
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(
|
||||
f"Invalid JSON format for extraction of chunk {chunk.chunk_id} \
|
||||
from doc {chunk.document_id}: {str(e)}"
|
||||
)
|
||||
logger.error(f"Raw output: {extraction_result}")
|
||||
return False, KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity=llm_preprocessing.core_entity,
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process chunk {chunk.chunk_id} from doc {chunk.document_id}: {str(e)}"
|
||||
)
|
||||
return False, KGUChunkUpdateRequest(
|
||||
document_id=chunk.document_id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
core_entity="",
|
||||
entities=set(),
|
||||
relationships=set(),
|
||||
terms=set(),
|
||||
)
|
||||
|
||||
# Assume for prototype: use_threads = True. TODO: Make thread safe!
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(process_single_chunk, (chunk, preformatted_prompt)) for chunk in chunks
|
||||
]
|
||||
|
||||
logger.debug("Running KG extraction on chunks in parallel")
|
||||
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
|
||||
|
||||
# Sort results into succeeded and failed
|
||||
for success, chunk_results in results:
|
||||
if success:
|
||||
succeeded_chunk_id.append(
|
||||
KGChunkId(
|
||||
document_id=chunk_results.document_id,
|
||||
chunk_id=chunk_results.chunk_id,
|
||||
)
|
||||
)
|
||||
succeeded_chunk_extraction.append(chunk_results)
|
||||
else:
|
||||
failed_chunk_id.append(
|
||||
KGChunkId(
|
||||
document_id=chunk_results.document_id,
|
||||
chunk_id=chunk_results.chunk_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Collect data for postgres later on
|
||||
|
||||
aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
|
||||
for chunk_result in succeeded_chunk_extraction:
|
||||
aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
chunk_result.core_entity
|
||||
] = chunk_result.document_id
|
||||
|
||||
mentioned_chunk_entities: set[str] = set()
|
||||
for relationship in chunk_result.relationships:
|
||||
relationship_split = relationship.split("__")
|
||||
if len(relationship_split) == 3:
|
||||
if relationship_split[0] not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[relationship_split[0]] += 1
|
||||
mentioned_chunk_entities.add(relationship_split[0])
|
||||
if relationship_split[2] not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[relationship_split[2]] += 1
|
||||
mentioned_chunk_entities.add(relationship_split[2])
|
||||
aggregated_kg_extractions.relationships[relationship] += 1
|
||||
|
||||
for kg_entity in chunk_result.entities:
|
||||
if kg_entity not in mentioned_chunk_entities:
|
||||
aggregated_kg_extractions.entities[kg_entity] += 1
|
||||
mentioned_chunk_entities.add(kg_entity)
|
||||
|
||||
for kg_term in chunk_result.terms:
|
||||
aggregated_kg_extractions.terms[kg_term] += 1
|
||||
|
||||
return KGBatchExtractionStats(
|
||||
connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
|
||||
succeeded=succeeded_chunk_id,
|
||||
failed=failed_chunk_id,
|
||||
aggregated_kg_extractions=aggregated_kg_extractions,
|
||||
)
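

# The master extraction prompt is expected to return JSON that parses into the
# keys read above ("entities", "relationships", "terms"), roughly of the form
# (hypothetical values):
#
#   {
#       "entities": ["ACCOUNT:Acme", "EMPLOYEE:Jane"],
#       "relationships": ["EMPLOYEE:Jane__had_call_with__ACCOUNT:Acme"],
#       "terms": ["renewal", "pricing"]
#   }
#
# Markdown json code fences, if present, are stripped before json.loads.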
|
||||
|
||||
|
||||
def _kg_document_classification(
|
||||
document_classification_content_list: list[KGClassificationContent],
|
||||
classification_instructions: dict[str, KGClassificationInstructionStrings],
|
||||
) -> list[tuple[bool, KGClassificationDecisions]]:
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
def classify_single_document(
|
||||
document_classification_content: KGClassificationContent,
|
||||
classification_instructions: dict[str, KGClassificationInstructionStrings],
|
||||
) -> tuple[bool, KGClassificationDecisions]:
|
||||
"""Classify a single document whether it should be kg-processed or not"""
|
||||
|
||||
source = document_classification_content.source_type
|
||||
document_id = document_classification_content.document_id
|
||||
|
||||
if source not in classification_instructions:
|
||||
logger.info(
|
||||
f"Source {source} did not have kg classification instructions. No content analysis."
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
classification_prompt = prepare_llm_document_content(
|
||||
document_classification_content,
|
||||
category_list=classification_instructions[source].classification_options,
|
||||
category_definitions=classification_instructions[
|
||||
source
|
||||
].classification_class_definitions,
|
||||
)
|
||||
|
||||
if classification_prompt.llm_prompt is None:
|
||||
logger.info(
|
||||
f"Source {source} did not have kg document classification instructions. No content analysis."
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=classification_prompt.llm_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"LLM Classification from document {document_classification_content.document_id}"
|
||||
)
|
||||
raw_classification_result = primary_llm.invoke(msg)
|
||||
classification_result = (
|
||||
message_to_string(raw_classification_result)
|
||||
.replace("```json", "")
|
||||
.replace("```", "")
|
||||
.strip()
|
||||
)
|
||||
|
||||
classification_class = classification_result.split("CATEGORY:")[1].strip()
|
||||
|
||||
if (
|
||||
classification_class
|
||||
in classification_instructions[source].classification_class_definitions
|
||||
):
|
||||
extraction_decision = cast(
|
||||
bool,
|
||||
classification_instructions[
|
||||
source
|
||||
].classification_class_definitions[classification_class][
|
||||
"extraction"
|
||||
],
|
||||
)
|
||||
else:
|
||||
extraction_decision = False
|
||||
|
||||
return True, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=extraction_decision,
|
||||
classification_class=classification_class,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to classify document {document_classification_content.document_id}: {str(e)}"
|
||||
)
|
||||
return False, KGClassificationDecisions(
|
||||
document_id=document_id,
|
||||
classification_decision=False,
|
||||
classification_class=None,
|
||||
)
|
||||
|
||||
# Assume for prototype: use_threads = True. TODO: Make thread safe!
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(
|
||||
classify_single_document,
|
||||
(document_classification_content, classification_instructions),
|
||||
)
|
||||
for document_classification_content in document_classification_content_list
|
||||
]
|
||||
|
||||
logger.debug("Running KG classification on documents in parallel")
|
||||
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# logger.debug("Running KG extraction on chunks in parallel")
|
||||
# results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=True)
|
||||
|
||||
# # Sort results into succeeded and failed
|
||||
# for success, chunk_results in results:
|
||||
# if success:
|
||||
# succeeded_chunk_id.append(
|
||||
# KGChunkId(
|
||||
# document_id=chunk_results.document_id,
|
||||
# chunk_id=chunk_results.chunk_id,
|
||||
# )
|
||||
# )
|
||||
# succeeded_chunk_extraction.append(chunk_results)
|
||||
# else:
|
||||
# failed_chunk_id.append(
|
||||
# KGChunkId(
|
||||
# document_id=chunk_results.document_id,
|
||||
# chunk_id=chunk_results.chunk_id,
|
||||
# )
|
||||
# )
|
||||
|
||||
# # Collect data for postgres later on
|
||||
|
||||
# aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
# grounded_entities_document_ids=defaultdict(str),
|
||||
# entities=defaultdict(int),
|
||||
# relationships=defaultdict(int),
|
||||
# terms=defaultdict(int),
|
||||
# )
|
||||
|
||||
# for chunk_result in succeeded_chunk_extraction:
|
||||
# aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
# chunk_result.core_entity
|
||||
# ] = chunk_result.document_id
|
||||
|
||||
# mentioned_chunk_entities: set[str] = set()
|
||||
# for relationship in chunk_result.relationships:
|
||||
# relationship_split = relationship.split("__")
|
||||
# if len(relationship_split) == 3:
|
||||
# if relationship_split[0] not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[relationship_split[0]] += 1
|
||||
# mentioned_chunk_entities.add(relationship_split[0])
|
||||
# if relationship_split[2] not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[relationship_split[2]] += 1
|
||||
# mentioned_chunk_entities.add(relationship_split[2])
|
||||
# aggregated_kg_extractions.relationships[relationship] += 1
|
||||
|
||||
# for kg_entity in chunk_result.entities:
|
||||
# if kg_entity not in mentioned_chunk_entities:
|
||||
# aggregated_kg_extractions.entities[kg_entity] += 1
|
||||
# mentioned_chunk_entities.add(kg_entity)
|
||||
|
||||
# for kg_term in chunk_result.terms:
|
||||
# aggregated_kg_extractions.terms[kg_term] += 1
|
||||
|
||||
# return KGBatchExtractionStats(
|
||||
# connector_id=chunks[0].connector_id if chunks else None, # TODO: Update!
|
||||
# succeeded=succeeded_chunk_id,
|
||||
# failed=failed_chunk_id,
|
||||
# aggregated_kg_extractions=aggregated_kg_extractions,
|
||||
# )
|
||||
|
||||
|
||||
def _kg_connector_extraction(
|
||||
connector_id: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"Starting kg extraction for connector {connector_id} for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# - grab kg type data from postgres
|
||||
|
||||
# - construct prompt
|
||||
|
||||
# find all documents for the connector that have not been kg-processed
|
||||
|
||||
# - loop for :
|
||||
|
||||
# - grab a number of chunks from vespa
|
||||
|
||||
# - convert them into the KGChunk format
|
||||
|
||||
# - run the extractions in parallel
|
||||
|
||||
# - save the results
|
||||
|
||||
# - mark chunks as processed
|
||||
|
||||
# - update the connector status
|
||||
|
||||
|
||||
#
|
||||
@@ -1,86 +0,0 @@
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
"""
|
||||
def kg_extraction_section(
|
||||
kg_chunks: list[KGChunk],
|
||||
llm: LLM,
|
||||
max_workers: int = 10,
|
||||
) -> list[KGChunkExtraction]:
|
||||
def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
|
||||
metadata_str = "\nMetadata:\n"
|
||||
for key, value in metadata.items():
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
metadata_str += f"{key} - {value_str}\n"
|
||||
return metadata_str
|
||||
|
||||
def _get_usefulness_messages() -> list[dict[str, str]]:
|
||||
metadata_str = _get_metadata_str(metadata) if metadata else ""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SECTION_FILTER_PROMPT.format(
|
||||
title=title.replace("\n", " "),
|
||||
chunk_text=section_content,
|
||||
user_query=query,
|
||||
optional_metadata=metadata_str,
|
||||
),
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
def _extract_usefulness(model_output: str) -> bool:
|
||||
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
messages = _get_usefulness_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(llm.invoke(filled_llm_prompt))
|
||||
logger.debug(model_output)
|
||||
|
||||
return _extract_usefulness(model_output)
|
||||
|
||||
"""
|
||||
"""
|
||||
def llm_batch_eval_sections(
|
||||
query: str,
|
||||
section_contents: list[str],
|
||||
llm: LLM,
|
||||
titles: list[str],
|
||||
metadata_list: list[dict[str, str | list[str]]],
|
||||
use_threads: bool = True,
|
||||
) -> list[bool]:
|
||||
if DISABLE_LLM_DOC_RELEVANCE:
|
||||
raise RuntimeError(
|
||||
"LLM Doc Relevance is globally disabled, "
|
||||
"this should have been caught upstream."
|
||||
)
|
||||
|
||||
if use_threads:
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(llm_eval_section, (query, section_content, llm, title, metadata))
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
"Running LLM usefulness eval in parallel (following logging may be out of order)"
|
||||
)
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
|
||||
# In case of failure/timeout, don't throw out the section
|
||||
return [True if item is None else item for item in parallel_results]
|
||||
|
||||
else:
|
||||
return [
|
||||
llm_eval_section(query, section_content, llm, title, metadata)
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
"""
|
||||
@@ -1,99 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KGChunkFormat(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
title: str
|
||||
content: str
|
||||
primary_owners: list[str]
|
||||
secondary_owners: list[str]
|
||||
source_type: str
|
||||
metadata: dict[str, str | list[str]] | None = None
|
||||
entities: Dict[str, int] = {}
|
||||
relationships: Dict[str, int] = {}
|
||||
terms: Dict[str, int] = {}
|
||||
|
||||
|
||||
class KGChunkExtraction(BaseModel):
|
||||
connector_id: int
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
|
||||
|
||||
class KGChunkId(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
|
||||
|
||||
class KGAggregatedExtractions(BaseModel):
|
||||
grounded_entities_document_ids: defaultdict[str, str]
|
||||
entities: defaultdict[str, int]
|
||||
relationships: defaultdict[str, int]
|
||||
terms: defaultdict[str, int]
|
||||
|
||||
|
||||
class KGBatchExtractionStats(BaseModel):
|
||||
connector_id: int | None = None
|
||||
succeeded: list[KGChunkId]
|
||||
failed: list[KGChunkId]
|
||||
aggregated_kg_extractions: KGAggregatedExtractions
|
||||
|
||||
|
||||
class ConnectorExtractionStats(BaseModel):
|
||||
connector_id: int
|
||||
num_succeeded: int
|
||||
num_failed: int
|
||||
num_processed: int
|
||||
|
||||
|
||||
class KGPerson(BaseModel):
|
||||
name: str
|
||||
company: str
|
||||
employee: bool
|
||||
|
||||
|
||||
class NormalizedEntities(BaseModel):
|
||||
entities: list[str]
|
||||
entity_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedRelationships(BaseModel):
|
||||
relationships: list[str]
|
||||
relationship_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedTerms(BaseModel):
|
||||
terms: list[str]
|
||||
term_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class KGClassificationContent(BaseModel):
|
||||
document_id: str
|
||||
classification_content: str
|
||||
source_type: str
|
||||
|
||||
|
||||
class KGClassificationDecisions(BaseModel):
|
||||
document_id: str
|
||||
classification_decision: bool
|
||||
classification_class: str | None
|
||||
|
||||
|
||||
class KGClassificationRule(BaseModel):
    description: str
    extraction: bool
|
||||
|
||||
|
||||
class KGClassificationInstructionStrings(BaseModel):
|
||||
classification_options: str
|
||||
classification_class_definitions: dict[str, Dict[str, str | bool]]
|
||||
@@ -1,55 +0,0 @@
from typing import Dict

from onyx.kg.context_preparations_extraction.fireflies import (
    prepare_llm_content_fireflies,
)
from onyx.kg.context_preparations_extraction.fireflies import (
    prepare_llm_document_content_fireflies,
)
from onyx.kg.context_preparations_extraction.models import ContextPreparation
from onyx.kg.context_preparations_extraction.models import (
    KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent


def prepare_llm_content(chunk: KGChunkFormat) -> ContextPreparation:
    """
    Prepare the content for the LLM.
    """
    if chunk.source_type == "fireflies":
        return prepare_llm_content_fireflies(chunk)

    else:
        return ContextPreparation(
            llm_context=chunk.content,
            core_entity="",
            implied_entities=[],
            implied_relationships=[],
            implied_terms=[],
        )


def prepare_llm_document_content(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definitions: dict[str, Dict[str, str | bool]],
) -> KGDocumentClassificationPrompt:
    """
    Prepare the content for the extraction classification.
    """

    category_definition_string = ""
    for category, category_data in category_definitions.items():
        category_definition_string += f"{category}: {category_data['description']}\n"

    if document_classification_content.source_type == "fireflies":
        return prepare_llm_document_content_fireflies(
            document_classification_content, category_list, category_definition_string
        )

    else:
        return KGDocumentClassificationPrompt(
            llm_prompt=None,
        )
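

# Illustrative dispatch behaviour (hypothetical chunk values): a "fireflies"
# chunk goes through the Fireflies-specific preparation, anything else falls
# back to using the raw chunk content with no implied graph items.
#
#   chunk = KGChunkFormat(
#       document_id="doc-1", chunk_id=0, title="Call", content="...",
#       primary_owners=[], secondary_owners=[], source_type="web",
#   )
#   prep = prepare_llm_content(chunk)
#   assert prep.core_entity == "" and prep.implied_entities == []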
@@ -1,41 +0,0 @@
from sqlalchemy import select

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.utils.logger import setup_logger

logger = setup_logger()


def get_unprocessed_connector_ids(tenant_id: str) -> list[int]:
    """
    Retrieves a list of connector IDs that have not been KG processed for a given tenant.

    Args:
        tenant_id (str): The ID of the tenant to check for unprocessed connectors

    Returns:
        list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
    """
    try:
        with get_session_with_current_tenant() as db_session:
            # Find connectors that:
            # 1. Have KG extraction enabled
            # 2. Have documents that haven't been KG processed
            stmt = (
                select(Connector.id)
                .distinct()
                .join(DocumentByConnectorCredentialPair)
                .where(
                    Connector.kg_extraction_enabled,
                    DocumentByConnectorCredentialPair.has_been_kg_processed.is_(False),
                )
            )

            result = db_session.execute(stmt)
            return [row[0] for row in result.fetchall()]

    except Exception as e:
        logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
        return []
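

# Illustrative usage (hypothetical tenant id); returns the ids of connectors
# that have KG extraction enabled and still have unprocessed documents:
#
#   for connector_id in get_unprocessed_connector_ids("tenant-123"):
#       logger.info(f"Connector {connector_id} has documents pending KG extraction")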
@@ -1,23 +0,0 @@
from typing import List

import numpy as np

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT


def encode_string_batch(strings: List[str]) -> np.ndarray:
    with get_session_with_current_tenant() as db_session:
        current_search_settings = get_current_search_settings(db_session)
        model = EmbeddingModel.from_db_model(
            search_settings=current_search_settings,
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )
        # Get embeddings while session is still open
        embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
        return np.array(embedding)
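

# Illustrative similarity computation on top of encode_string_batch, mirroring
# how normalize_relationships picks the best candidate (hypothetical strings):
#
#   vectors = encode_string_batch(["had call with", "signed contract with"])
#   query, candidate = vectors[0], vectors[1]
#   score = float(np.dot(query, candidate))  # higher means more similar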
@@ -1,123 +0,0 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.configs.kg_configs import KG_OWN_COMPANY
|
||||
from onyx.configs.kg_configs import KG_OWN_EMAIL_DOMAINS
|
||||
from onyx.kg.models import KGAggregatedExtractions
|
||||
from onyx.kg.models import KGPerson
|
||||
|
||||
|
||||
def format_entity(entity: str) -> str:
|
||||
if len(entity.split(":")) == 2:
|
||||
entity_type, entity_name = entity.split(":")
|
||||
return f"{entity_type.upper()}:{entity_name.title()}"
|
||||
else:
|
||||
return entity
|
||||
|
||||
|
||||
def format_relationship(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{format_entity(source_node)}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{format_entity(target_node)}"
|
||||
)
|
||||
|
||||
|
||||
def format_relationship_type(relationship_type: str) -> str:
|
||||
source_node_type, relationship_type, target_node_type = relationship_type.split(
|
||||
"__"
|
||||
)
|
||||
return (
|
||||
f"{source_node_type.upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node_type.upper()}"
|
||||
)
|
||||
|
||||
|
||||
def generate_relationship_type(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{source_node.split(':')[0].upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node.split(':')[0].upper()}"
|
||||
)
|
||||
|
||||
|
||||
def aggregate_kg_extractions(
|
||||
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
|
||||
) -> KGAggregatedExtractions:
|
||||
aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(int),
|
||||
terms=defaultdict(int),
|
||||
)
|
||||
|
||||
for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
|
||||
for (
|
||||
grounded_entity,
|
||||
document_id,
|
||||
) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
|
||||
aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
grounded_entity
|
||||
] = document_id
|
||||
|
||||
for entity, count in connector_aggregated_kg_extractions.entities.items():
|
||||
aggregated_kg_extractions.entities[entity] += count
|
||||
for (
|
||||
relationship,
|
||||
count,
|
||||
) in connector_aggregated_kg_extractions.relationships.items():
|
||||
aggregated_kg_extractions.relationships[relationship] += count
|
||||
for term, count in connector_aggregated_kg_extractions.terms.items():
|
||||
aggregated_kg_extractions.terms[term] += count
|
||||
return aggregated_kg_extractions
|
||||
|
||||
|
||||
def kg_email_processing(email: str) -> KGPerson:
|
||||
"""
|
||||
Parse an email address into a KGPerson and flag whether it belongs to one of the company's own domains.
|
||||
"""
|
||||
name, company_domain = email.split("@")
|
||||
|
||||
assert isinstance(KG_OWN_EMAIL_DOMAINS, list)
|
||||
|
||||
employee = any(domain in company_domain for domain in KG_OWN_EMAIL_DOMAINS)
|
||||
if employee:
|
||||
company = KG_OWN_COMPANY
|
||||
else:
|
||||
company = company_domain.capitalize()
|
||||
|
||||
return KGPerson(name=name, company=company, employee=employee)
|
||||
|
||||
|
||||
def generalize_entities(entities: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize entities to their superclass.
|
||||
"""
|
||||
return set([f"{entity.split(':')[0]}:*" for entity in entities])
|
||||
|
||||
|
||||
def generalize_relationships(relationships: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize relationships to their superclass.
|
||||
"""
|
||||
generalized_relationships: set[str] = set()
|
||||
for relationship in relationships:
|
||||
assert (
|
||||
len(relationship.split("__")) == 3
|
||||
), "Relationship is not in the correct format"
|
||||
source_entity, relationship_type, target_entity = relationship.split("__")
|
||||
generalized_source_entity = list(generalize_entities([source_entity]))[0]
|
||||
generalized_target_entity = list(generalize_entities([target_entity]))[0]
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
return generalized_relationships
|
||||
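Illustrative inputs and outputs for the ID-string conventions used by the helpers above (the example values are assumptions, not taken from this diff):

format_entity("account:acme corp")
# -> "ACCOUNT:Acme Corp"
format_relationship("account:acme corp__OWNS__opportunity:big deal")
# -> "ACCOUNT:Acme Corp__owns__OPPORTUNITY:Big Deal"
generate_relationship_type("account:acme corp__OWNS__opportunity:big deal")
# -> "ACCOUNT__owns__OPPORTUNITY"
generalize_entities(["ACCOUNT:Acme Corp"])
# -> {"ACCOUNT:*"}
generalize_relationships(["ACCOUNT:Acme Corp__owns__OPPORTUNITY:Big Deal"])
# -> {"ACCOUNT:*__owns__OPPORTUNITY:Big Deal",
#     "ACCOUNT:Acme Corp__owns__OPPORTUNITY:*",
#     "ACCOUNT:*__owns__OPPORTUNITY:*"}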
@@ -1,179 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.kg.context_preparations_extraction.fireflies import (
|
||||
get_classification_content_from_fireflies_chunks,
|
||||
)
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
|
||||
|
||||
def get_document_classification_content_for_kg_processing(
|
||||
document_ids: list[str],
|
||||
index_name: str,
|
||||
batch_size: int = 8,
|
||||
num_classification_chunks: int = 3,
|
||||
) -> Generator[list[KGClassificationContent], None, None]:
|
||||
"""
|
||||
Generates the content used for initial classification of a document from
|
||||
the first num_classification_chunks chunks.
|
||||
"""
|
||||
|
||||
classification_content_list: list[KGClassificationContent] = []
|
||||
|
||||
for i in range(0, len(document_ids), batch_size):
|
||||
batch_document_ids = document_ids[i : i + batch_size]
|
||||
for document_id in batch_document_ids:
|
||||
# ... existing code for getting chunks and processing ...
|
||||
first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(
|
||||
document_id=document_id,
|
||||
max_chunk_ind=num_classification_chunks - 1,
|
||||
min_chunk_ind=0,
|
||||
),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"source_type",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
first_num_classification_chunks = sorted(
|
||||
first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
|
||||
)[:num_classification_chunks]
|
||||
|
||||
classification_content = _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks
|
||||
)
|
||||
|
||||
classification_content_list.append(
|
||||
KGClassificationContent(
|
||||
document_id=document_id,
|
||||
classification_content=classification_content,
|
||||
source_type=first_num_classification_chunks[0]["fields"][
|
||||
"source_type"
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# Yield the batch of classification content
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
classification_content_list = []
|
||||
|
||||
# Yield any remaining items
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
|
||||
|
||||
def get_document_chunks_for_kg_processing(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
batch_size: int = 8,
|
||||
) -> Generator[list[KGChunkFormat], None, None]:
|
||||
"""
|
||||
Retrieves chunks from Vespa for the given document ID and converts them to KGChunkFormat objects.

Args:
document_id (str): ID of the document to fetch chunks for
index_name (str): Name of the Vespa index
batch_size (int): Number of chunks per yielded batch

Yields:
list[KGChunkFormat]: Batches of chunks ready for KG processing
|
||||
"""
|
||||
|
||||
current_batch: list[KGChunkFormat] = []
|
||||
|
||||
# get all chunks for the document
|
||||
chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
"source_type",
|
||||
"kg_entities",
|
||||
"kg_relationships",
|
||||
"kg_terms",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
# Convert Vespa chunks to KGChunks
|
||||
# kg_chunks: list[KGChunkFormat] = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
fields = chunk["fields"]
|
||||
if isinstance(fields.get("metadata", {}), str):
|
||||
fields["metadata"] = json.loads(fields["metadata"])
|
||||
current_batch.append(
|
||||
KGChunkFormat(
|
||||
connector_id=None, # We may need to adjust this
|
||||
document_id=fields.get("document_id"),
|
||||
chunk_id=fields.get("chunk_id"),
|
||||
primary_owners=fields.get("primary_owners", []),
|
||||
secondary_owners=fields.get("secondary_owners", []),
|
||||
source_type=fields.get("source_type", ""),
|
||||
title=fields.get("title", ""),
|
||||
content=fields.get("content", ""),
|
||||
metadata=fields.get("metadata", {}),
|
||||
entities=fields.get("kg_entities", {}),
|
||||
relationships=fields.get("kg_relationships", {}),
|
||||
terms=fields.get("kg_terms", {}),
|
||||
)
|
||||
)
|
||||
|
||||
if len(current_batch) >= batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
# Yield any remaining chunks
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
|
||||
def _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
) -> str:
|
||||
"""
|
||||
Builds the classification content string from a list of chunks.
|
||||
"""
|
||||
|
||||
source_type = first_num_classification_chunks[0]["fields"]["source_type"]
|
||||
|
||||
if source_type == "fireflies":
|
||||
classification_content = get_classification_content_from_fireflies_chunks(
|
||||
first_num_classification_chunks
|
||||
)
|
||||
|
||||
else:
|
||||
classification_content = (
|
||||
first_num_classification_chunks[0]["fields"]["title"]
|
||||
+ "\n"
|
||||
+ "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return classification_content
|
||||
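A usage sketch for the chunk generator above; the document ID and index name are placeholders:

# Hedged sketch: feed per-document chunk batches into downstream KG extraction.
for chunk_batch in get_document_chunks_for_kg_processing(
    document_id="example-doc-id",      # placeholder
    index_name="example_chunk_index",  # placeholder Vespa index name
    batch_size=8,
):
    for kg_chunk in chunk_batch:
        # each item is a KGChunkFormat carrying content plus any existing kg_* fields
        print(kg_chunk.document_id, kg_chunk.chunk_id)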
@@ -126,13 +126,16 @@ def get_default_llm_with_vision(
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Try the default vision provider first
|
||||
default_provider = fetch_default_vision_provider(db_session)
|
||||
if default_provider and default_provider.default_vision_model:
|
||||
if model_supports_image_input(
|
||||
if (
|
||||
default_provider
|
||||
and default_provider.default_vision_model
|
||||
and model_supports_image_input(
|
||||
default_provider.default_vision_model, default_provider.provider
|
||||
):
|
||||
return create_vision_llm(
|
||||
default_provider, default_provider.default_vision_model
|
||||
)
|
||||
)
|
||||
):
|
||||
return create_vision_llm(
|
||||
default_provider, default_provider.default_vision_model
|
||||
)
|
||||
|
||||
# Fall back to searching all providers
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
@@ -140,36 +143,14 @@ def get_default_llm_with_vision(
|
||||
if not providers:
|
||||
return None
|
||||
|
||||
# Check all providers for viable vision models
|
||||
# Find the first provider that supports image input
|
||||
for provider in providers:
|
||||
provider_view = LLMProviderView.from_model(provider)
|
||||
|
||||
# First priority: Check if provider has a default_vision_model
|
||||
if provider.default_vision_model and model_supports_image_input(
|
||||
provider.default_vision_model, provider.provider
|
||||
):
|
||||
return create_vision_llm(provider_view, provider.default_vision_model)
|
||||
|
||||
# If no model_names are specified, try default models in priority order
|
||||
if not provider.model_names:
|
||||
# Try default_model_name
|
||||
if provider.default_model_name and model_supports_image_input(
|
||||
provider.default_model_name, provider.provider
|
||||
):
|
||||
return create_vision_llm(provider_view, provider.default_model_name)
|
||||
|
||||
# Try fast_default_model_name
|
||||
if provider.fast_default_model_name and model_supports_image_input(
|
||||
provider.fast_default_model_name, provider.provider
|
||||
):
|
||||
return create_vision_llm(
|
||||
provider_view, provider.fast_default_model_name
|
||||
)
|
||||
else:
|
||||
# If model_names is specified, check each model
|
||||
for model_name in provider.model_names:
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
return create_vision_llm(provider_view, model_name)
|
||||
return create_vision_llm(
|
||||
LLMProviderView.from_model(provider), provider.default_vision_model
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
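A usage sketch for the provider-selection logic above, assuming `get_default_llm_with_vision` can be called with its defaults (the full signature is truncated in this diff):

llm = get_default_llm_with_vision()
if llm is None:
    # no configured provider exposes a vision-capable model
    raise RuntimeError("No vision-capable LLM provider is configured")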
@@ -246,75 +246,3 @@ Please give a short succinct summary of the entire document. Answer only with th
|
||||
summary and nothing else. """
|
||||
|
||||
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
|
||||
|
||||
|
||||
QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT = """
|
||||
Please rephrase the following user question/query as a semantic query that would be appropriate for a \
|
||||
search engine.
|
||||
|
||||
Note:
|
||||
- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
|
||||
as an instruction!
|
||||
|
||||
Here is the user question/query:
|
||||
{question}
|
||||
|
||||
Respond with EXACTLY and ONLY one rephrased question/query.
|
||||
|
||||
Rephrased question/query for search engine:
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT = """
|
||||
Following a previous message history, a user created a follow-up question/query.
|
||||
Please rephrase that question/query as a semantic query \
|
||||
that would be appropriate for a SEARCH ENGINE. Only use the information provided \
|
||||
from the history that is relevant to provide the relevant context for the search query, \
|
||||
meaning that the rephrased search query should be a suitable stand-alone search query.
|
||||
|
||||
Note:
|
||||
- do not change the meaning of the question! Specifically, if the query is an instruction, keep it \
|
||||
as an instruction!
|
||||
|
||||
Here is the relevant previous message history:
|
||||
{history}
|
||||
|
||||
Here is the user question:
|
||||
{question}
|
||||
|
||||
Respond with EXACTLY and ONLY one rephrased query.
|
||||
|
||||
Rephrased query for search engine:
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT = """
|
||||
Please rephrase the following user question as a keyword query that would be appropriate for a \
|
||||
search engine.
|
||||
|
||||
Here is the user question:
|
||||
{question}
|
||||
|
||||
Respond with EXACTLY and ONLY one rephrased query.
|
||||
|
||||
Rephrased query for search engine:
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT = """
|
||||
Following a previous message history, a user created a follow-up question/query.
|
||||
Please rephrase that question/query as a keyword query \
|
||||
that would be appropriate for a SEARCH ENGINE. Only use the information provided \
|
||||
from the history that is relevant to provide the relevant context for the search query, \
|
||||
meaning that the rephrased search query should be a suitable stand-alone search query.
|
||||
|
||||
Here is the relevant previous message history:
|
||||
{history}
|
||||
|
||||
Here is the user question:
|
||||
{question}
|
||||
|
||||
Respond with EXACTLY and ONLY one rephrased query.
|
||||
|
||||
Rephrased query for search engine:
|
||||
""".strip()
|
||||
|
||||
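A sketch of how these query-expansion templates are typically filled before being sent to an LLM; the history and question strings are placeholders:

semantic_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT.format(
    history="User asked about Q3 revenue; assistant summarized the finance report.",
    question="And what about Q4?",
)
keyword_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT.format(
    question="How do I rotate the API key for the Slack connector?",
)
# each prompt asks the model for exactly one rephrased, stand-alone search query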
File diff suppressed because it is too large
@@ -783,7 +783,6 @@ def get_connector_indexing_status(
|
||||
name=cc_pair.name,
|
||||
in_progress=in_progress,
|
||||
cc_pair_status=cc_pair.status,
|
||||
in_repeated_error_state=cc_pair.in_repeated_error_state,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
connector, connector_to_cc_pair_ids.get(connector.id, [])
|
||||
),
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -129,7 +128,6 @@ class CredentialBase(BaseModel):
|
||||
class CredentialSnapshot(CredentialBase):
|
||||
id: int
|
||||
user_id: UUID | None
|
||||
user_email: str | None = None
|
||||
time_created: datetime
|
||||
time_updated: datetime
|
||||
|
||||
@@ -143,7 +141,6 @@ class CredentialSnapshot(CredentialBase):
|
||||
else credential.credential_json
|
||||
),
|
||||
user_id=credential.user_id,
|
||||
user_email=credential.user.email if credential.user else None,
|
||||
admin_public=credential.admin_public,
|
||||
time_created=credential.time_created,
|
||||
time_updated=credential.time_updated,
|
||||
@@ -210,7 +207,6 @@ class CCPairFullInfo(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
status: ConnectorCredentialPairStatus
|
||||
in_repeated_error_state: bool
|
||||
num_docs_indexed: int
|
||||
connector: ConnectorSnapshot
|
||||
credential: CredentialSnapshot
|
||||
@@ -224,13 +220,6 @@ class CCPairFullInfo(BaseModel):
|
||||
creator: UUID | None
|
||||
creator_email: str | None
|
||||
|
||||
# information on syncing/indexing
|
||||
last_indexed: datetime | None
|
||||
last_pruned: datetime | None
|
||||
last_permission_sync: datetime | None
|
||||
overall_indexing_speed: float | None
|
||||
latest_checkpoint_description: str | None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
@@ -248,8 +237,7 @@ class CCPairFullInfo(BaseModel):
|
||||
# there is a mismatch between these two numbers which may confuse users.
|
||||
last_indexing_status = last_index_attempt.status if last_index_attempt else None
|
||||
if (
|
||||
# only need to do this if the last indexing attempt is still in progress
|
||||
last_indexing_status == IndexingStatus.IN_PROGRESS
|
||||
last_indexing_status == IndexingStatus.SUCCESS
|
||||
and number_of_index_attempts == 1
|
||||
and last_index_attempt
|
||||
and last_index_attempt.new_docs_indexed
|
||||
@@ -258,18 +246,10 @@ class CCPairFullInfo(BaseModel):
|
||||
last_index_attempt.new_docs_indexed if last_index_attempt else 0
|
||||
)
|
||||
|
||||
overall_indexing_speed = num_docs_indexed / (
|
||||
(
|
||||
datetime.now(tz=timezone.utc) - cc_pair_model.connector.time_created
|
||||
).total_seconds()
|
||||
/ 60
|
||||
)
|
||||
|
||||
return cls(
|
||||
id=cc_pair_model.id,
|
||||
name=cc_pair_model.name,
|
||||
status=cc_pair_model.status,
|
||||
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
|
||||
num_docs_indexed=num_docs_indexed,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_model.connector
|
||||
@@ -288,15 +268,6 @@ class CCPairFullInfo(BaseModel):
|
||||
creator_email=(
|
||||
cc_pair_model.creator.email if cc_pair_model.creator else None
|
||||
),
|
||||
last_indexed=(
|
||||
last_index_attempt.time_started if last_index_attempt else None
|
||||
),
|
||||
last_pruned=cc_pair_model.last_pruned,
|
||||
last_permission_sync=(
|
||||
last_index_attempt.time_started if last_index_attempt else None
|
||||
),
|
||||
overall_indexing_speed=overall_indexing_speed,
|
||||
latest_checkpoint_description=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -337,9 +308,6 @@ class ConnectorIndexingStatus(ConnectorStatus):
|
||||
"""Represents the full indexing status of a connector"""
|
||||
|
||||
cc_pair_status: ConnectorCredentialPairStatus
|
||||
# this is separate from the `status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`,
|
||||
# or `PAUSED` and still be in a repeated error state.
|
||||
in_repeated_error_state: bool
|
||||
owner: str
|
||||
last_finished_status: IndexingStatus | None
|
||||
last_status: IndexingStatus | None
|
||||
|
||||
@@ -118,13 +118,6 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
|
||||
# Safely get groups - handle detached instance case
|
||||
try:
|
||||
groups = [group.id for group in llm_provider_model.groups]
|
||||
except Exception:
|
||||
# If groups relationship can't be loaded (detached instance), use empty list
|
||||
groups = []
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -155,7 +148,7 @@ class LLMProviderView(LLMProvider):
|
||||
else None
|
||||
),
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=groups,
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import QueryExpansions
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@@ -75,17 +74,11 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
precomputed_keywords: list[str] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
document_sources: list[DocumentSource] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
ordering_only: bool | None = (
|
||||
None # Flag for fast path when search is only needed for ordering
|
||||
)
|
||||
document_sources: list[DocumentSource] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -295,7 +295,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
ordering_only = False
|
||||
document_sources = None
|
||||
time_cutoff = None
|
||||
expanded_queries = None
|
||||
if override_kwargs:
|
||||
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
@@ -308,10 +307,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
|
||||
document_sources = override_kwargs.document_sources
|
||||
time_cutoff = override_kwargs.time_cutoff
|
||||
expanded_queries = override_kwargs.expanded_queries
|
||||
kg_entities = override_kwargs.kg_entities
|
||||
kg_relationships = override_kwargs.kg_relationships
|
||||
kg_terms = override_kwargs.kg_terms
|
||||
|
||||
# Fast path for ordering-only search
|
||||
if ordering_only:
|
||||
@@ -361,16 +356,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Override time-cutoff should supersede the existing time-cutoff, even if defined
|
||||
retrieval_options.filters.time_cutoff = time_cutoff
|
||||
|
||||
# Initialize kg filters in retrieval options and filters with all provided values
|
||||
retrieval_options = retrieval_options or RetrievalDetails()
|
||||
retrieval_options.filters = retrieval_options.filters or BaseFilters()
|
||||
if kg_entities:
|
||||
retrieval_options.filters.kg_entities = kg_entities
|
||||
if kg_relationships:
|
||||
retrieval_options.filters.kg_relationships = kg_relationships
|
||||
if kg_terms:
|
||||
retrieval_options.filters.kg_terms = kg_terms
|
||||
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
@@ -406,8 +391,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
precomputed_query_embedding=precomputed_query_embedding,
|
||||
precomputed_is_keyword=precomputed_is_keyword,
|
||||
precomputed_keywords=precomputed_keywords,
|
||||
# add expanded queries
|
||||
expanded_queries=expanded_queries,
|
||||
),
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
|
||||
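A sketch of how a caller might pass KG filters through the override kwargs handled above; the entity and relationship strings are illustrative:

override_kwargs = SearchToolOverrideKwargs(
    kg_entities=["ACCOUNT:Acme Corp"],
    kg_relationships=["ACCOUNT:*__owns__OPPORTUNITY:*"],
    kg_terms=["renewal"],
)
# SearchTool copies these onto retrieval_options.filters (kg_entities,
# kg_relationships, kg_terms) before building the SearchPipeline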
@@ -69,7 +69,7 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
|
||||
"NOTSET": logging.NOTSET,
|
||||
}
|
||||
|
||||
return log_level_dict.get(log_level_str.upper(), logging.INFO)
|
||||
return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
|
||||
|
||||
|
||||
class OnyxRequestIDFilter(logging.Filter):
|
||||
|
||||
@@ -80,7 +80,6 @@ slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
starlette==0.46.1
|
||||
supervisor==4.2.5
|
||||
thefuzz==0.22.1
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.49.0
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
@@ -515,7 +514,6 @@ def get_number_of_chunks_we_think_exist(
|
||||
class VespaDebugging:
|
||||
# Class for managing Vespa debugging actions.
|
||||
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
self.tenant_id = tenant_id
|
||||
self.index_name = get_index_name(self.tenant_id)
|
||||
@@ -857,7 +855,6 @@ def delete_documents_for_tenant(
|
||||
|
||||
|
||||
def main() -> None:
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
parser = argparse.ArgumentParser(description="Vespa debugging tool")
|
||||
parser.add_argument(
|
||||
"--action",
|
||||
|
||||
@@ -5,7 +5,6 @@ RUN THIS AFTER SEED_DUMMY_DOCS.PY
|
||||
import random
|
||||
import time
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -97,7 +96,6 @@ def test_hybrid_retrieval_times(
|
||||
hybrid_alpha=0.5,
|
||||
time_decay_multiplier=1.0,
|
||||
num_to_retrieve=50,
|
||||
ranking_profile_type=QueryExpansionType.SEMANTIC,
|
||||
offset=0,
|
||||
title_content_ratio=0.5,
|
||||
)
|
||||
|
||||
@@ -64,7 +64,7 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"
|
||||
# Enable generating persistent log files for local dev environments
|
||||
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
|
||||
# notset, debug, info, notice, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL") or "info"
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
|
||||
# Timeout for API-based embedding models
|
||||
# NOTE: does not apply for Google VertexAI, since the python client doesn't
|
||||
|
||||
@@ -69,7 +69,6 @@ class CCPairManager:
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestCCPair:
|
||||
connector = ConnectorManager.create(
|
||||
name=name,
|
||||
@@ -79,7 +78,6 @@ class CCPairManager:
|
||||
access_type=access_type,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
refresh_freq=refresh_freq,
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
credential_json=credential_json,
|
||||
|
||||
@@ -78,19 +78,15 @@ class ChatSessionManager:
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
)
|
||||
|
||||
headers = (
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
)
|
||||
cookies = user_performing_action.cookies if user_performing_action else None
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message",
|
||||
json=chat_message_req.model_dump(),
|
||||
headers=headers,
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
stream=True,
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
return ChatSessionManager.analyze_response(response)
|
||||
|
||||
@@ -23,7 +23,6 @@ class ConnectorManager:
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
refresh_freq: int | None = None,
|
||||
) -> DATestConnector:
|
||||
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
|
||||
|
||||
@@ -37,7 +36,6 @@ class ConnectorManager:
|
||||
),
|
||||
access_type=access_type,
|
||||
groups=groups or [],
|
||||
refresh_freq=refresh_freq,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
|
||||
@@ -231,11 +231,6 @@ class DocumentManager:
|
||||
for doc_dict in retrieved_docs_dict:
|
||||
doc_id = doc_dict["fields"]["document_id"]
|
||||
doc_content = doc_dict["fields"]["content"]
|
||||
image_file_name = doc_dict["fields"].get("image_file_name", None)
|
||||
final_docs.append(
|
||||
SimpleTestDocument(
|
||||
id=doc_id, content=doc_content, image_file_name=image_file_name
|
||||
)
|
||||
)
|
||||
final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
|
||||
|
||||
return final_docs
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import mimetypes
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
@@ -63,43 +62,3 @@ class FileManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@staticmethod
|
||||
def upload_file_for_connector(
|
||||
file_path: str, file_name: str, user_performing_action: DATestUser
|
||||
) -> dict:
|
||||
# Read the file content
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
|
||||
# Create a file-like object
|
||||
file_obj = io.BytesIO(file_content)
|
||||
|
||||
# The 'files' form field expects a list of files
|
||||
files = [("files", (file_name, file_obj, "application/octet-stream"))]
|
||||
|
||||
# Use the user's headers but without Content-Type
|
||||
# as requests will set the correct multipart/form-data Content-Type for us
|
||||
headers = user_performing_action.headers.copy()
|
||||
if "Content-Type" in headers:
|
||||
del headers["Content-Type"]
|
||||
|
||||
# Make the request
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
|
||||
files=files,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
try:
|
||||
error_detail = response.json().get("detail", "Unknown error")
|
||||
except Exception:
|
||||
error_detail = response.text
|
||||
|
||||
raise Exception(
|
||||
f"Unable to upload files - {error_detail} (Status code: {response.status_code})"
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
return response_json
|
||||
|
||||
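A usage sketch for the upload helper above, mirroring how the image-indexing test later in this diff calls it; the path and user fixture are placeholders:

upload_response = FileManager.upload_file_for_connector(
    file_path="tests/integration/common_utils/test_files/Sample.pdf",
    file_name="Sample.pdf",
    user_performing_action=admin_user,  # a DATestUser fixture
)
file_paths = upload_response.get("file_paths", [])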
@@ -23,7 +23,7 @@ class SettingsManager:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/settings",
|
||||
f"{API_SERVER_URL}/api/manage/admin/settings",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
@@ -48,8 +48,8 @@ class SettingsManager:
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
payload = settings.model_dump()
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/settings",
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/api/manage/admin/settings",
|
||||
json=payload,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
Binary file not shown.
@@ -76,7 +76,6 @@ class DATestConnector(BaseModel):
|
||||
class SimpleTestDocument(BaseModel):
|
||||
id: str
|
||||
content: str
|
||||
image_file_name: str | None = None
|
||||
|
||||
|
||||
class DATestCCPair(BaseModel):
|
||||
@@ -178,8 +177,6 @@ class DATestSettings(BaseModel):
|
||||
gpu_enabled: bool | None = None
|
||||
product_gating: DATestGatingType = DATestGatingType.NONE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
image_extraction_and_analysis_enabled: bool | None = False
|
||||
search_time_image_analysis_enabled: bool | None = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -6,8 +6,7 @@ INVITED_BASIC_USER = "basic_user"
|
||||
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"
|
||||
|
||||
|
||||
def test_admin_can_invite_users(reset_multitenant: None) -> None:
|
||||
"""Test that an admin can invite both registered and non-registered users."""
|
||||
def test_user_invitation_flow(reset_multitenant: None) -> None:
|
||||
# Create first user (admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin")
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
@@ -20,44 +19,16 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None:
|
||||
UserManager.invite_user(invited_user.email, admin_user)
|
||||
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
|
||||
|
||||
# Verify users are in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
assert invited_user.email in [
|
||||
user.email for user in invited_users
|
||||
], f"User {invited_user.email} not found in invited users list"
|
||||
|
||||
|
||||
def test_non_registered_user_gets_basic_role(reset_multitenant: None) -> None:
|
||||
"""Test that a non-registered user gets a BASIC role when they register after being invited."""
|
||||
# Create admin user
|
||||
admin_user: DATestUser = UserManager.create(name="admin")
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Admin user invites a non-registered user
|
||||
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
|
||||
|
||||
# Non-registered user registers
|
||||
invited_basic_user: DATestUser = UserManager.create(
|
||||
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
|
||||
)
|
||||
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)
|
||||
|
||||
|
||||
def test_user_can_accept_invitation(reset_multitenant: None) -> None:
|
||||
"""Test that a user can accept an invitation and join the organization with BASIC role."""
|
||||
# Create admin user
|
||||
admin_user: DATestUser = UserManager.create(name="admin")
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Create a user to be invited
|
||||
invited_user_email = "invited_user@test.com"
|
||||
|
||||
# User registers with the same email as the invitation
|
||||
invited_user: DATestUser = UserManager.create(
|
||||
name="invited_user", email=invited_user_email
|
||||
)
|
||||
# Admin user invites the user
|
||||
UserManager.invite_user(invited_user_email, admin_user)
|
||||
# Verify the user is in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
assert invited_user.email in [
|
||||
user.email for user in invited_users
|
||||
], f"User {invited_user.email} not found in invited users list"
|
||||
|
||||
# Get user info to check tenant information
|
||||
user_info = UserManager.get_user_info(invited_user)
|
||||
@@ -70,17 +41,16 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None:
|
||||
)
|
||||
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"
|
||||
|
||||
# User accepts invitation
|
||||
UserManager.accept_invitation(invited_tenant_id, invited_user)
|
||||
|
||||
# User needs to reauthenticate after accepting invitation
|
||||
# Simulate this by creating a new user instance with the same credentials
|
||||
authenticated_user: DATestUser = UserManager.create(
|
||||
name="invited_user", email=invited_user_email
|
||||
)
|
||||
# Get updated user info after accepting invitation
|
||||
updated_user_info = UserManager.get_user_info(invited_user)
|
||||
|
||||
# Get updated user info after accepting invitation and reauthenticating
|
||||
updated_user_info = UserManager.get_user_info(authenticated_user)
|
||||
# Verify the user is no longer in the invited users list
|
||||
updated_invited_users = UserManager.get_invited_users(admin_user)
|
||||
assert invited_user.email not in [
|
||||
user.email for user in updated_invited_users
|
||||
], f"User {invited_user.email} should not be in invited users list after accepting"
|
||||
|
||||
# Verify the user has BASIC role in the organization
|
||||
assert (
|
||||
@@ -94,7 +64,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None:
|
||||
|
||||
# Check if the invited user is in the list of users with BASIC role
|
||||
invited_user_emails = [user.email for user in user_page.items]
|
||||
assert invited_user_email in invited_user_emails, (
|
||||
f"User {invited_user_email} not found in the list of basic users "
|
||||
assert invited_user.email in invited_user_emails, (
|
||||
f"User {invited_user.email} not found in the list of basic users "
|
||||
f"in the organization. Available users: {invited_user_emails}"
|
||||
)
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.db.models import UserRole
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
@@ -13,12 +11,12 @@ from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
|
||||
"""Helper function to set up test tenants with documents and users."""
|
||||
# Creating an admin user for Tenant 1
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
# Creating an admin user (the first user created is automatically an admin and also provisions the tenant)
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email="admin@onyx-test.com",
|
||||
)
|
||||
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
@@ -37,16 +35,6 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
|
||||
api_key_1.headers.update(admin_user1.headers)
|
||||
LLMProviderManager.create(user_performing_action=admin_user1)
|
||||
|
||||
# Create connectors for Tenant 2
|
||||
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2.headers.update(admin_user2.headers)
|
||||
LLMProviderManager.create(user_performing_action=admin_user2)
|
||||
|
||||
# Seed documents for Tenant 1
|
||||
cc_pair_1.documents = []
|
||||
doc1_tenant1 = DocumentManager.seed_doc_with_content(
|
||||
@@ -61,6 +49,16 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
|
||||
)
|
||||
cc_pair_1.documents.extend([doc1_tenant1, doc2_tenant1])
|
||||
|
||||
# Create connectors for Tenant 2
|
||||
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2.headers.update(admin_user2.headers)
|
||||
LLMProviderManager.create(user_performing_action=admin_user2)
|
||||
|
||||
# Seed documents for Tenant 2
|
||||
cc_pair_2.documents = []
|
||||
doc1_tenant2 = DocumentManager.seed_doc_with_content(
|
||||
@@ -86,36 +84,21 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]:
|
||||
user_performing_action=admin_user2
|
||||
)
|
||||
|
||||
return {
|
||||
"admin_user1": admin_user1,
|
||||
"admin_user2": admin_user2,
|
||||
"chat_session1": chat_session1,
|
||||
"chat_session2": chat_session2,
|
||||
"tenant1_doc_ids": tenant1_doc_ids,
|
||||
"tenant2_doc_ids": tenant2_doc_ids,
|
||||
}
|
||||
|
||||
|
||||
def test_tenant1_can_access_own_documents(reset_multitenant: None) -> None:
|
||||
"""Test that Tenant 1 can access its own documents but not Tenant 2's."""
|
||||
test_data = setup_test_tenants(reset_multitenant)
|
||||
|
||||
# User 1 sends a message and gets a response
|
||||
response1 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session1"].id,
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_data["admin_user1"],
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
|
||||
# Assert that the search tool was used
|
||||
assert response1.tool_name == "run_search"
|
||||
|
||||
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
|
||||
assert test_data["tenant1_doc_ids"].issubset(
|
||||
assert tenant1_doc_ids.issubset(
|
||||
response_doc_ids
|
||||
), "Not all Tenant 1 document IDs are in the response"
|
||||
assert not response_doc_ids.intersection(
|
||||
test_data["tenant2_doc_ids"]
|
||||
tenant2_doc_ids
|
||||
), "Tenant 2 document IDs should not be in the response"
|
||||
|
||||
# Assert that the contents are correct
|
||||
@@ -124,28 +107,21 @@ def test_tenant1_can_access_own_documents(reset_multitenant: None) -> None:
|
||||
for doc in response1.tool_result or []
|
||||
), "Tenant 1 Document Content not found in any document"
|
||||
|
||||
|
||||
def test_tenant2_can_access_own_documents(reset_multitenant: None) -> None:
|
||||
"""Test that Tenant 2 can access its own documents but not Tenant 1's."""
|
||||
test_data = setup_test_tenants(reset_multitenant)
|
||||
|
||||
# User 2 sends a message and gets a response
|
||||
response2 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session2"].id,
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_data["admin_user2"],
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
|
||||
# Assert that the search tool was used
|
||||
assert response2.tool_name == "run_search"
|
||||
|
||||
# Assert that the tool_result contains Tenant 2's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
|
||||
assert test_data["tenant2_doc_ids"].issubset(
|
||||
assert tenant2_doc_ids.issubset(
|
||||
response_doc_ids
|
||||
), "Not all Tenant 2 document IDs are in the response"
|
||||
assert not response_doc_ids.intersection(
|
||||
test_data["tenant1_doc_ids"]
|
||||
tenant1_doc_ids
|
||||
), "Tenant 1 document IDs should not be in the response"
|
||||
|
||||
# Assert that the contents are correct
|
||||
@@ -154,91 +130,28 @@ def test_tenant2_can_access_own_documents(reset_multitenant: None) -> None:
|
||||
for doc in response2.tool_result or []
|
||||
), "Tenant 2 Document Content not found in any document"
|
||||
|
||||
|
||||
def test_tenant1_cannot_access_tenant2_documents(reset_multitenant: None) -> None:
|
||||
"""Test that Tenant 1 cannot access Tenant 2's documents."""
|
||||
test_data = setup_test_tenants(reset_multitenant)
|
||||
|
||||
# User 1 tries to access Tenant 2's documents
|
||||
response_cross = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session1"].id,
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_data["admin_user1"],
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
|
||||
# Assert that the search tool was used
|
||||
assert response_cross.tool_name == "run_search"
|
||||
|
||||
# Assert that the tool_result is empty or does not contain Tenant 2's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
|
||||
|
||||
# Ensure none of Tenant 2's document IDs are in the response
|
||||
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
|
||||
|
||||
|
||||
def test_tenant2_cannot_access_tenant1_documents(reset_multitenant: None) -> None:
|
||||
"""Test that Tenant 2 cannot access Tenant 1's documents."""
|
||||
test_data = setup_test_tenants(reset_multitenant)
|
||||
assert not response_doc_ids.intersection(tenant2_doc_ids)
|
||||
|
||||
# User 2 tries to access Tenant 1's documents
|
||||
response_cross2 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session2"].id,
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_data["admin_user2"],
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
|
||||
# Assert that the search tool was used
|
||||
assert response_cross2.tool_name == "run_search"
|
||||
|
||||
# Assert that the tool_result is empty or does not contain Tenant 1's documents
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
|
||||
|
||||
# Ensure none of Tenant 1's document IDs are in the response
|
||||
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
|
||||
|
||||
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
"""Legacy test for multi-tenant access control."""
|
||||
test_data = setup_test_tenants(reset_multitenant)
|
||||
|
||||
# User 1 sends a message and gets a response with only Tenant 1's documents
|
||||
response1 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session1"].id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_data["admin_user1"],
|
||||
)
|
||||
assert response1.tool_name == "run_search"
|
||||
response_doc_ids = {doc["document_id"] for doc in response1.tool_result or []}
|
||||
assert test_data["tenant1_doc_ids"].issubset(response_doc_ids)
|
||||
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
|
||||
|
||||
# User 2 sends a message and gets a response with only Tenant 2's documents
|
||||
response2 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session2"].id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_data["admin_user2"],
|
||||
)
|
||||
assert response2.tool_name == "run_search"
|
||||
response_doc_ids = {doc["document_id"] for doc in response2.tool_result or []}
|
||||
assert test_data["tenant2_doc_ids"].issubset(response_doc_ids)
|
||||
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
|
||||
|
||||
# User 1 tries to access Tenant 2's documents and fails
|
||||
response_cross = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session1"].id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_data["admin_user1"],
|
||||
)
|
||||
assert response_cross.tool_name == "run_search"
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross.tool_result or []}
|
||||
assert not response_doc_ids.intersection(test_data["tenant2_doc_ids"])
|
||||
|
||||
# User 2 tries to access Tenant 1's documents and fails
|
||||
response_cross2 = ChatSessionManager.send_message(
|
||||
chat_session_id=test_data["chat_session2"].id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_data["admin_user2"],
|
||||
)
|
||||
assert response_cross2.tool_name == "run_search"
|
||||
response_doc_ids = {doc["document_id"] for doc in response_cross2.tool_result or []}
|
||||
assert not response_doc_ids.intersection(test_data["tenant1_doc_ids"])
|
||||
assert not response_doc_ids.intersection(tenant1_doc_ids)
|
||||
|
||||
@@ -8,51 +8,12 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_first_user_is_admin(reset_multitenant: None) -> None:
|
||||
"""Test that the first user of a tenant is automatically assigned ADMIN role."""
|
||||
# Test flow from creating tenant to registering as a user
|
||||
def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
|
||||
assert UserManager.is_role(test_user, UserRole.ADMIN)
|
||||
|
||||
|
||||
def test_admin_can_create_credential(reset_multitenant: None) -> None:
|
||||
"""Test that an admin user can create a credential in their tenant."""
|
||||
# Create admin user
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
assert UserManager.is_role(test_user, UserRole.ADMIN)
|
||||
|
||||
# Create credential
|
||||
test_credential = CredentialManager.create(
|
||||
name="admin_test_credential",
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=False,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
assert test_credential is not None
|
||||
|
||||
|
||||
def test_admin_can_create_connector(reset_multitenant: None) -> None:
|
||||
"""Test that an admin user can create a connector in their tenant."""
|
||||
# Create admin user
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
assert UserManager.is_role(test_user, UserRole.ADMIN)
|
||||
|
||||
# Create connector
|
||||
test_connector = ConnectorManager.create(
|
||||
name="admin_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
access_type=AccessType.PRIVATE,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
assert test_connector is not None
|
||||
|
||||
|
||||
def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
|
||||
"""Test that an admin user can create and verify a connector-credential pair in their tenant."""
|
||||
# Create admin user
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
assert UserManager.is_role(test_user, UserRole.ADMIN)
|
||||
|
||||
# Create credential
|
||||
test_credential = CredentialManager.create(
|
||||
name="admin_test_credential",
|
||||
source=DocumentSource.FILE,
|
||||
@@ -60,7 +21,6 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
# Create connector
|
||||
test_connector = ConnectorManager.create(
|
||||
name="admin_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
@@ -68,7 +28,6 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
# Create cc_pair
|
||||
test_cc_pair = CCPairManager.create(
|
||||
connector_id=test_connector.id,
|
||||
credential_id=test_credential.id,
|
||||
@@ -76,7 +35,5 @@ def test_admin_can_create_and_verify_cc_pair(reset_multitenant: None) -> None:
|
||||
access_type=AccessType.PRIVATE,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
assert test_cc_pair is not None
|
||||
|
||||
# Verify cc_pair
|
||||
CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=test_user)
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.file import FileManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.settings import SettingsManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
FILE_NAME = "Sample.pdf"
|
||||
FILE_PATH = "tests/integration/common_utils/test_files"
|
||||
|
||||
|
||||
def test_image_indexing(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email="admin@onyx-test.com",
|
||||
)
|
||||
|
||||
os.makedirs(FILE_PATH, exist_ok=True)
|
||||
test_file_path = os.path.join(FILE_PATH, FILE_NAME)
|
||||
|
||||
# Use FileManager to upload the test file
|
||||
upload_response = FileManager.upload_file_for_connector(
|
||||
file_path=test_file_path, file_name=FILE_NAME, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
LLMProviderManager.create(
|
||||
name="test_llm",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(
|
||||
search_time_image_analysis_enabled=True,
|
||||
image_extraction_and_analysis_enabled=True,
|
||||
),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
file_paths = upload_response.get("file_paths", [])
|
||||
|
||||
if not file_paths:
|
||||
pytest.fail("File upload failed - no file paths returned")
|
||||
|
||||
# Create a dummy credential for the file connector
|
||||
credential = CredentialManager.create(
|
||||
source=DocumentSource.FILE,
|
||||
credential_json={},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create the connector
|
||||
connector_name = f"FileConnector-{int(datetime.now().timestamp())}"
|
||||
connector = ConnectorManager.create(
|
||||
name=connector_name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={"file_locations": file_paths},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Link the credential to the connector
|
||||
cc_pair = CCPairManager.create(
|
||||
credential_id=credential.id,
|
||||
connector_id=connector.id,
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Explicitly run the connector to start indexing
|
||||
CCPairManager.run_once(
|
||||
cc_pair=cc_pair,
|
||||
from_beginning=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=datetime.now(timezone.utc),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
|
||||
# Ensure we indexed an image from the sample.pdf file
|
||||
has_sample_pdf_image = False
|
||||
for doc in documents:
|
||||
if doc.image_file_name and FILE_NAME in doc.image_file_name:
|
||||
has_sample_pdf_image = True
|
||||
|
||||
# Assert that at least one document has an image file name containing "sample.pdf"
|
||||
assert (
|
||||
has_sample_pdf_image
|
||||
), "No document found with an image file name containing 'sample.pdf'"
|
||||
@@ -1,18 +0,0 @@
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
|
||||
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_client() -> httpx.Client:
|
||||
print(
|
||||
f"Initializing mock server client with host: "
|
||||
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
|
||||
f"{MOCK_CONNECTOR_SERVER_PORT}"
|
||||
)
|
||||
return httpx.Client(
|
||||
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
|
||||
timeout=5.0,
|
||||
)
|
||||
@@ -4,6 +4,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
|
||||
@@ -25,6 +26,19 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_server_client() -> httpx.Client:
|
||||
print(
|
||||
f"Initializing mock server client with host: "
|
||||
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
|
||||
f"{MOCK_CONNECTOR_SERVER_PORT}"
|
||||
)
|
||||
return httpx.Client(
|
||||
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
|
||||
def test_mock_connector_basic_flow(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
|
||||
@@ -1,204 +0,0 @@
import time
import uuid

import httpx

from onyx.background.celery.tasks.indexing.utils import (
    NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE,
)
from onyx.configs.constants import DocumentSource
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import InputType
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture


def test_repeated_error_state_detection_and_recovery(
    mock_server_client: httpx.Client,
    vespa_client: vespa_fixture,
    admin_user: DATestUser,
) -> None:
    """Test that a connector is marked as in a repeated error state after
    NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE consecutive failures, and
    that it recovers after a successful indexing.

    This test ensures we properly wait for the required number of indexing attempts
    to fail before checking that the connector is in a repeated error state."""

    # Create test document for successful response
    test_doc = create_test_document()

    # First, set up the mock server to consistently fail
    error_response = {
        "documents": [],
        "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(mode="json"),
        "failures": [],
        "unhandled_exception": "Simulated unhandled error for testing repeated errors",
    }

    # Create a list of failure responses with at least the same length
    # as NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
    failure_behaviors = [error_response] * (
        5 * NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
    )

    response = mock_server_client.post(
        "/set-behavior",
        json=failure_behaviors,
    )
    assert response.status_code == 200

    # Create a new CC pair for testing
    cc_pair = CCPairManager.create_from_scratch(
        name=f"mock-repeated-error-{uuid.uuid4()}",
        source=DocumentSource.MOCK_CONNECTOR,
        input_type=InputType.POLL,
        connector_specific_config={
            "mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
            "mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
        },
        user_performing_action=admin_user,
        refresh_freq=60 * 60,  # a very long time
    )

    # Wait for the required number of failed indexing attempts
    # This shouldn't take long, since we keep retrying while we haven't
    # succeeded yet
    start_time = time.monotonic()
    while True:
        index_attempts_page = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair.id,
            page=0,
            page_size=100,
            user_performing_action=admin_user,
        )
        index_attempts = [
            ia
            for ia in index_attempts_page.items
            if ia.status and ia.status.is_terminal()
        ]
        if len(index_attempts) == NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE:
            break

        if time.monotonic() - start_time > 180:
            raise TimeoutError(
                "Did not get required number of failed attempts within 180 seconds"
            )

        # make sure that we don't mark the connector as in repeated error state
        # before we have the required number of failed attempts
        with get_session_context_manager() as db_session:
            cc_pair_obj = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
            )
            assert cc_pair_obj is not None
            assert not cc_pair_obj.in_repeated_error_state

        time.sleep(2)

    # Verify we have the correct number of failed attempts
    assert len(index_attempts) == NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
    for attempt in index_attempts:
        assert attempt.status == IndexingStatus.FAILED

    # Check if the connector is in a repeated error state
    start_time = time.monotonic()
    while True:
        with get_session_context_manager() as db_session:
            cc_pair_obj = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
            )
            assert cc_pair_obj is not None
            if cc_pair_obj.in_repeated_error_state:
                break

        if time.monotonic() - start_time > 30:
            assert False, "CC pair did not enter repeated error state within 30 seconds"

        time.sleep(2)

    # Reset the mock server state
    response = mock_server_client.post("/reset")
    assert response.status_code == 200

    # Now set up the mock server to succeed
    success_response = {
        "documents": [test_doc.model_dump(mode="json")],
        "checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(mode="json"),
        "failures": [],
    }

    response = mock_server_client.post(
        "/set-behavior",
        json=[success_response],
    )
    assert response.status_code == 200

    # Run another indexing attempt that should succeed
    CCPairManager.run_once(
        cc_pair, from_beginning=True, user_performing_action=admin_user
    )

    recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
        cc_pair_id=cc_pair.id,
        index_attempts_to_ignore=[index_attempt.id for index_attempt in index_attempts],
        user_performing_action=admin_user,
    )

    IndexAttemptManager.wait_for_index_attempt_completion(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )

    # Validate the indexing succeeded
    finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
        index_attempt_id=recovery_index_attempt.id,
        cc_pair_id=cc_pair.id,
        user_performing_action=admin_user,
    )
    assert finished_recovery_attempt.status == IndexingStatus.SUCCESS

    # Verify the document was indexed
    with get_session_context_manager() as db_session:
        documents = DocumentManager.fetch_documents_for_cc_pair(
            cc_pair_id=cc_pair.id,
            db_session=db_session,
            vespa_client=vespa_client,
        )
    assert len(documents) == 1
    assert documents[0].id == test_doc.id

    # Verify the CC pair is no longer in a repeated error state
    start = time.monotonic()
    while True:
        with get_session_context_manager() as db_session:
            cc_pair_obj = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
            )
            assert cc_pair_obj is not None
            if not cc_pair_obj.in_repeated_error_state:
                break

        elapsed = time.monotonic() - start
        if elapsed > 30:
            raise TimeoutError(
                "CC pair did not exit repeated error state within 30 seconds"
            )

        print(
            f"Waiting for CC pair to exit repeated error state. elapsed={elapsed:.2f}"
        )
        time.sleep(1)
Some files were not shown because too many files have changed in this diff.