Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-03-01 13:45:44 +00:00

Compare commits: v2.3.0-bet ... KG_dev_cop (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 0e38898c7a |  |
|  | ce6a597eca |  |
|  | d251ba40ae |  |
|  | 26395d81c9 |  |
|  | e1a3e11ec9 |  |
|  | e013711664 |  |
.github/workflows/pr-integration-tests.yml (vendored, 4 changes)
@@ -147,6 +147,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \

@@ -245,6 +247,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \

@@ -183,6 +183,8 @@ jobs:
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
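For orientation, the two new environment variables are read by the backend configuration module that the migrations below import. A minimal sketch of that pattern follows; the default values are assumptions, not taken from this diff.

import os

# Hypothetical sketch of how onyx.configs.app_configs could expose the new
# settings used by the migrations in this PR; defaults are assumed, not confirmed.
DB_READONLY_USER = os.environ.get("DB_READONLY_USER", "")
DB_READONLY_PASSWORD = os.environ.get("DB_READONLY_PASSWORD", "")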
@@ -0,0 +1,690 @@
"""create knowledge graph tables

Revision ID: 495cb26ce93e
Revises: ca04500b9ee8
Create Date: 2025-03-19 08:51:14.341989

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from datetime import datetime, timedelta

from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT


# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "ca04500b9ee8"
branch_labels = None
depends_on = None


def upgrade() -> None:

    # Create a new permission-less user to be later used for knowledge graph queries.
    # The user will later get temporary read privileges for a specific view that will be
    # ad hoc generated specific to a knowledge graph query.
    #
    # Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
    # environment variables MUST be set. Otherwise, an exception will be raised.

    print("MULTI_TENANT: ", MULTI_TENANT)
    if not MULTI_TENANT:

        print("Single tenant mode")

        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

        # Create read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is created in the alembic_tenants migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    -- Check if the read-only user already exists
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- Create the read-only user with the specified password
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- First revoke all privileges to ensure a clean slate
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege to allow the user to connect to the database
                        -- but not perform any operations without additional specific grants
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

    # Grant usage on current schema to readonly user

    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    EXECUTE format('GRANT USAGE ON SCHEMA %I TO %I', current_schema(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )

    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
        sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
        sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
        sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
    )

    # Insert initial data into kg_config table
    op.bulk_insert(
        sa.table(
            "kg_config",
            sa.column("kg_variable_name", sa.String),
            sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
        ),
        [
            {"kg_variable_name": "KG_EXPOSED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_ENABLED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_VENDOR", "kg_variable_values": []},
            {"kg_variable_name": "KG_VENDOR_DOMAINS", "kg_variable_values": []},
            {"kg_variable_name": "KG_IGNORE_EMAIL_DOMAINS", "kg_variable_values": []},
            {
                "kg_variable_name": "KG_EXTRACTION_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                "kg_variable_name": "KG_CLUSTERING_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                "kg_variable_name": "KG_COVERAGE_START",
                "kg_variable_values": [
                    (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
                ],
            },
            {"kg_variable_name": "KG_MAX_COVERAGE_DAYS", "kg_variable_values": ["90"]},
            {
                "kg_variable_name": "KG_MAX_PARENT_RECURSION_DEPTH",
                "kg_variable_values": ["2"],
            },
        ],
    )

    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column(
            "attributes",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True),
        sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
    )

    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGRelationshipTypeExtractionStaging table
    op.create_table(
        "kg_relationship_type_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        sa.Column("name_trigrams", postgresql.ARRAY(sa.String(3)), nullable=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
        sa.UniqueConstraint(
            "name",
            "entity_type_id_name",
            "document_id",
            name="uq_kg_entity_name_type_doc",
        ),
    )
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    # Create KGEntityExtractionStaging table
    op.create_table(
        "kg_entity_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("transferred_id_name", sa.String(), nullable=True, default=None),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("parent_key", sa.String(), nullable=True, index=True),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
    )
    op.create_index(
        "ix_entity_extraction_staging_acl",
        "kg_entity_extraction_staging",
        ["entity_type_id_name", "acl"],
    )
    op.create_index(
        "ix_entity_extraction_staging_name_search",
        "kg_entity_extraction_staging",
        ["name", "entity_type_id_name"],
    )

    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    # Create KGRelationshipExtractionStaging table
    op.create_table(
        "kg_relationship_extraction_staging",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"],
            ["kg_relationship_type_extraction_staging.id_name"],
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_extraction_staging_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_extraction_staging_nodes",
        "kg_relationship_extraction_staging",
        ["source_node", "target_node"],
    )

    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])

    op.add_column(
        "document",
        sa.Column("kg_stage", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "document",
        sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_processing_enabled",
            sa.Boolean(),
            nullable=True,
            server_default="false",
        ),
    )

    op.add_column(
        "connector",
        sa.Column(
            "kg_coverage_days",
            sa.Integer(),
            nullable=True,
            server_default=None,
        ),
    )

    # Create GIN index for clustering and normalization
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
        "ON kg_entity USING GIN (name public.gin_trgm_ops)"
    )
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
        "ON kg_entity USING GIN (name_trigrams)"
    )

    # Create kg_entity trigger to update kg_entity.name and its trigrams
    alphanum_pattern = r"[^a-z0-9]+"
    truncate_length = 1000
    function = "update_kg_entity_name"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                name text;
                cleaned_name text;
            BEGIN
                -- Set name to semantic_id if document_id is not NULL
                IF NEW.document_id IS NOT NULL THEN
                    SELECT lower(semantic_id) INTO name
                    FROM document
                    WHERE id = NEW.document_id;
                ELSE
                    name = lower(NEW.name);
                END IF;

                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams
                NEW.name = name;
                NEW.name_trigrams = public.show_trgm(cleaned_name);
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON kg_entity")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        BEFORE INSERT OR UPDATE OF name
        ON kg_entity
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )

    # Create kg_entity trigger to update kg_entity.name and its trigrams
    function = "update_kg_entity_name_from_doc"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                doc_name text;
                cleaned_name text;
            BEGIN
                doc_name = lower(NEW.semantic_id);

                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    doc_name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams for all entities referencing this document
                UPDATE kg_entity
                SET
                    name = doc_name,
                    name_trigrams = public.show_trgm(cleaned_name)
                WHERE document_id = NEW.id;
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON document")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        AFTER UPDATE OF semantic_id
        ON document
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )


def downgrade() -> None:

    # Drop all views that start with 'kg_'
    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'kg_relationships_with_access%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )

    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'allowed_docs%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )

    for table, function in (
        ("kg_entity", "update_kg_entity_name"),
        ("document", "update_kg_entity_name_from_doc"),
    ):
        op.execute(f"DROP TRIGGER IF EXISTS {function}_trigger ON {table}")
        op.execute(f"DROP FUNCTION IF EXISTS {function}()")

    # Drop index
    op.execute("COMMIT")  # Commit to allow CONCURRENTLY
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")

    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_relationship_extraction_staging")
    op.drop_table("kg_relationship_type_extraction_staging")
    op.drop_table("kg_entity_extraction_staging")
    op.drop_table("kg_entity_type")
    op.drop_column("connector", "kg_processing_enabled")
    op.drop_column("connector", "kg_coverage_days")
    op.drop_column("document", "kg_stage")
    op.drop_column("document", "kg_processing_time")
    op.drop_table("kg_config")

    if not MULTI_TENANT:
        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is dropped in the alembic_tenants migration.

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
    else:
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
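As the comment at the top of upgrade() explains, the read-only role created here is meant to receive temporary SELECT rights on an ad hoc, per-query view. A minimal sketch of that usage pattern follows; the function and view names are illustrative and not taken from this PR.

from sqlalchemy import text
from sqlalchemy.orm import Session


def grant_select_on_temp_view(db_session: Session, view_name: str, readonly_user: str) -> None:
    # Illustrative only: give the permission-less role read access to one
    # ad hoc view for the duration of a knowledge graph query.
    db_session.execute(text(f'GRANT SELECT ON {view_name} TO "{readonly_user}"'))


def drop_temp_view(db_session: Session, view_name: str, readonly_user: str) -> None:
    # Revoke the temporary grant and remove the per-query view again.
    db_session.execute(text(f'REVOKE SELECT ON {view_name} FROM "{readonly_user}"'))
    db_session.execute(text(f"DROP VIEW IF EXISTS {view_name}"))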
@@ -0,0 +1,79 @@
"""add_db_readonly_user

Revision ID: 3b9f09038764
Revises: 3b45e0018bf1
Create Date: 2025-05-11 11:05:11.436977

"""

from sqlalchemy import text

from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT


# revision identifiers, used by Alembic.
revision = "3b9f09038764"
down_revision = "3b45e0018bf1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    if MULTI_TENANT:

        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

        # Create read-only db user here only in multi-tenant mode. For single-tenant mode,
        # the user is created in the standard migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    -- Check if the read-only user already exists
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- Create the read-only user with the specified password
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- First revoke all privileges to ensure a clean slate
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege to allow the user to connect to the database
                        -- but not perform any operations without additional specific grants
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )


def downgrade() -> None:
    if MULTI_TENANT:
        # Drop the read-only db user here only in multi-tenant mode. For single-tenant mode,
        # the user is dropped in the standard migration.

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then revoke all privileges from the public schema
                        EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
@@ -150,14 +150,12 @@ def research_object_source(
            ),
        )
    ]
    # fast_llm = graph_config.tooling.fast_llm
    primary_llm = graph_config.tooling.primary_llm
    llm = primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,

@@ -71,15 +71,12 @@ def consolidate_object_research(
            ),
        )
    ]
    graph_config.tooling.primary_llm
    # fast_llm = graph_config.tooling.fast_llm
    primary_llm = graph_config.tooling.primary_llm
    llm = primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            30,
            llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
@@ -0,0 +1,70 @@
from collections.abc import Hashable
from datetime import datetime
from enum import Enum

from langgraph.types import Send

from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput


class KGAnalysisPath(str, Enum):
    PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
    CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"


def simple_vs_search(
    state: MainState,
) -> str:

    identified_strategy = state.updated_strategy or state.strategy

    if (
        identified_strategy == KGAnswerStrategy.DEEP
        or state.search_type == KGSearchType.SEARCH
    ):
        return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
    else:
        return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value


def research_individual_object(
    state: MainState,
) -> list[Send | Hashable] | str:
    edge_start_time = datetime.now()

    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None

    if (
        state.search_type == KGSearchType.SQL
        and state.strategy == KGAnswerStrategy.DEEP
    ):

        return [
            Send(
                "process_individual_deep_search",
                ResearchObjectInput(
                    research_nr=research_nr + 1,
                    entity=entity,
                    broken_down_question=state.broken_down_question,
                    vespa_filter_results=state.vespa_filter_results,
                    source_division=state.source_division,
                    source_entity_filters=state.source_filters,
                    log_messages=[
                        f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                    step_results=[],
                ),
            )
            for research_nr, entity in enumerate(state.div_con_entities)
        ]
    elif state.search_type == KGSearchType.SEARCH:
        return "filtered_search"
    else:
        raise ValueError(
            f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
        )
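For readers less familiar with LangGraph: when a conditional edge returns a list of Send objects, as research_individual_object does for the SQL + DEEP case, the runtime schedules one invocation of the named node per Send, each with its own input payload. A minimal sketch of that fan-out pattern follows; the payload shape is illustrative and not the real ResearchObjectInput.

from langgraph.types import Send


def fan_out_example(entities: list[str]) -> list[Send]:
    # One "process_individual_deep_search" run per entity, each carrying its own input.
    return [Send("process_individual_deep_search", {"entity": entity}) for entity in entities]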
backend/onyx/agents/agent_search/kb_search/graph_builder.py (new file, 143 lines)
@@ -0,0 +1,143 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.kb_search.conditional_edges import (
    research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
    generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
    construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
    process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
    consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
    process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger

logger = setup_logger()


def kb_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """

    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###

    graph.add_node(
        "extract_ert",
        extract_ert,
    )

    graph.add_node(
        "generate_simple_sql",
        generate_simple_sql,
    )

    graph.add_node(
        "filtered_search",
        filtered_search,
    )

    graph.add_node(
        "analyze",
        analyze,
    )

    graph.add_node(
        "generate_answer",
        generate_answer,
    )

    graph.add_node(
        "log_data",
        log_data,
    )

    graph.add_node(
        "construct_deep_search_filters",
        construct_deep_search_filters,
    )

    graph.add_node(
        "process_individual_deep_search",
        process_individual_deep_search,
    )

    graph.add_node(
        "consolidate_individual_deep_search",
        consolidate_individual_deep_search,
    )

    graph.add_node("process_kg_only_answers", process_kg_only_answers)

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="extract_ert")

    graph.add_edge(
        start_key="extract_ert",
        end_key="analyze",
    )

    graph.add_edge(
        start_key="analyze",
        end_key="generate_simple_sql",
    )

    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)

    graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")

    graph.add_conditional_edges(
        source="construct_deep_search_filters",
        path=research_individual_object,
        path_map=["process_individual_deep_search", "filtered_search"],
    )

    graph.add_edge(
        start_key="process_individual_deep_search",
        end_key="consolidate_individual_deep_search",
    )

    graph.add_edge(
        start_key="consolidate_individual_deep_search", end_key="generate_answer"
    )

    graph.add_edge(
        start_key="filtered_search",
        end_key="generate_answer",
    )

    graph.add_edge(
        start_key="generate_answer",
        end_key="log_data",
    )

    graph.add_edge(
        start_key="log_data",
        end_key=END,
    )

    return graph
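A hedged usage sketch of the builder above: the returned StateGraph still has to be compiled before it can be invoked or streamed. The concrete MainInput fields are not shown in this diff, so they are left as a placeholder here.

from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder

compiled_graph = kb_graph_builder().compile()
# Assumed invocation pattern; the actual MainInput construction lives elsewhere in onyx:
# for update in compiled_graph.stream(main_input, config={"metadata": {"config": graph_config}}):
#     print(update)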
backend/onyx/agents/agent_search/kb_search/graph_utils.py (new file, 409 lines)
@@ -0,0 +1,409 @@
import re
from time import sleep

from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import LlmDoc
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _check_entities_disconnected(
    current_entities: list[str], current_relationships: list[str]
) -> bool:
    """
    Check if all entities in current_entities are disconnected via the given relationships.
    Relationships are in the format: source_entity__relationship_name__target_entity

    Args:
        current_entities: List of entity IDs to check connectivity for
        current_relationships: List of relationships in format source__relationship__target

    Returns:
        bool: True if all entities are disconnected, False otherwise
    """
    if not current_entities:
        return True

    # Create a graph representation using adjacency list
    graph: dict[str, set[str]] = {entity: set() for entity in current_entities}

    # Build the graph from relationships
    for relationship in current_relationships:
        try:
            source, _, target = split_relationship_id(relationship)
            if source in graph and target in graph:
                graph[source].add(target)
                # Add reverse edge to capture that we do also have a relationship in the other direction,
                # albeit not quite the same one.
                graph[target].add(source)
        except ValueError:
            raise ValueError(f"Invalid relationship format: {relationship}")

    # Use BFS to check if all entities are connected
    visited: set[str] = set()
    start_entity = current_entities[0]

    def _bfs(start: str) -> None:
        queue = [start]
        visited.add(start)
        while queue:
            current = queue.pop(0)
            for neighbor in graph[current]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

    # Start BFS from the first entity
    _bfs(start_entity)

    logger.debug(f"Number of visited entities: {len(visited)}")

    # Check if all current_entities are in visited
    return not all(entity in visited for entity in current_entities)


def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
    """
    TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
    Return the original entities and relationships.
    """
    return KGExpandedGraphObjects(entities=entities, relationships=relationships)


def stream_write_step_description(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:

    write_custom_event(
        "decomp_qs",
        SubQuestionPiece(
            sub_question=STEP_DESCRIPTIONS[step_nr].description,
            level=level,
            level_question_num=step_nr,
        ),
        writer,
    )

    # Give the frontend a brief moment to catch up
    sleep(0.2)


def stream_write_step_activities(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
        write_custom_event(
            "subqueries",
            SubQueryPiece(
                sub_query=activity,
                level=level,
                level_question_num=step_nr,
                query_id=activity_nr + 1,
            ),
            writer,
        )


def stream_write_step_activity_explicit(
    writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
) -> None:
    for activity in STEP_DESCRIPTIONS[step_nr].activities:
        write_custom_event(
            "subqueries",
            SubQueryPiece(
                sub_query=activity,
                level=level,
                level_question_num=step_nr,
                query_id=query_id,
            ),
            writer,
        )


def stream_write_step_answer_explicit(
    writer: StreamWriter, step_nr: int, answer: str, level: int = 0
) -> None:
    write_custom_event(
        "sub_answers",
        AgentAnswerPiece(
            answer_piece=answer,
            level=level,
            level_question_num=step_nr,
            answer_type="agent_sub_answer",
        ),
        writer,
    )


def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
    for step_nr, step_detail in STEP_DESCRIPTIONS.items():

        write_custom_event(
            "decomp_qs",
            SubQuestionPiece(
                sub_question=step_detail.description,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )

    for step_nr in STEP_DESCRIPTIONS.keys():

        write_custom_event(
            "stream_finished",
            StreamStopInfo(
                stop_reason=StreamStopReason.FINISHED,
                stream_type=StreamType.SUB_QUESTIONS,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )

    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_QUESTIONS,
        level=0,
    )

    write_custom_event("stream_finished", stop_event, writer)


def stream_close_step_answer(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_ANSWER,
        level=level,
        level_question_num=step_nr,
    )
    write_custom_event("stream_finished", stop_event, writer)


def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_QUESTIONS,
        level=level,
    )

    write_custom_event("stream_finished", stop_event, writer)


def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.MAIN_ANSWER,
        level=level,
        level_question_num=0,
    )
    write_custom_event("stream_finished", stop_event, writer)


def stream_write_main_answer_token(
    writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
) -> None:
    write_custom_event(
        "initial_agent_answer",
        AgentAnswerPiece(
            answer_piece=token,  # No need to add space as tokenizer handles this
            level=level,
            level_question_num=level_question_num,
            answer_type="agent_level_answer",
        ),
        writer,
    )


def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
    """
    Get document information for an entity, including its semantic name and document details.
    """
    if "::" not in entity_id_name:
        return KGEntityDocInfo(
            doc_id=None,
            doc_semantic_id=None,
            doc_link=None,
            semantic_entity_name=entity_id_name,
            semantic_linked_entity_name=entity_id_name,
        )

    entity_type, entity_name = map(str.strip, entity_id_name.split("::", 1))

    with get_session_with_current_tenant() as db_session:
        entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
        if entity_document_id:
            return get_kg_doc_info_for_entity_name(
                db_session, entity_document_id, entity_type
            )
        else:
            entity_actual_name = get_entity_name(db_session, entity_id_name)

            return KGEntityDocInfo(
                doc_id=None,
                doc_semantic_id=None,
                doc_link=None,
                semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
                semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
            )


def rename_entities_in_answer(answer: str) -> str:
    """
    Process entity references in the answer string by:
    1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
    2. Looking up these references in the entity table
    3. Replacing valid references with their corresponding values

    Args:
        answer: The input string containing potential entity references

    Returns:
        str: The processed string with entity references replaced
    """
    logger.debug(f"Input answer: {answer}")

    # Clean up any spaces around ::
    answer = re.sub(r"::\s+", "::", answer)
    logger.debug(f"After cleaning spaces: {answer}")

    # Pattern to match entity_type::entity_name, with optional quotes
    pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
    logger.debug(f"Using pattern: {pattern}")

    matches = list(re.finditer(pattern, answer))
    logger.debug(f"Found {len(matches)} matches")

    # get active entity types
    with get_session_with_current_tenant() as db_session:
        active_entity_types = [
            x.id_name for x in get_entity_types(db_session, active=True)
        ]
    logger.debug(f"Active entity types: {active_entity_types}")

    # Create dictionary for processed references
    processed_refs = {}

    for match in matches:
        entity_type = match.group(1).upper().strip()
        entity_name = match.group(2).strip()
        potential_entity_id_name = make_entity_id(entity_type, entity_name)
        logger.debug(f"Processing entity: {potential_entity_id_name}")

        if entity_type not in active_entity_types:
            logger.debug(f"Entity type {entity_type} not in active types")
            continue

        replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)

        if replacement_candidate.doc_id:
            # Store both the original match and the entity_id_name for replacement
            processed_refs[match.group(0)] = (
                replacement_candidate.semantic_linked_entity_name
            )
            logger.debug(
                f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_linked_entity_name}"
            )
        else:
            processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
            logger.debug(
                f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_entity_name}"
            )

    # Replace all references in the answer
    for ref, replacement in processed_refs.items():
        answer = answer.replace(ref, replacement)
        logger.debug(f"Replaced {ref} with {replacement}")

    return answer


def build_document_context(
    document: InferenceSection | LlmDoc, document_number: int
) -> str:
    """
    Build a context string for a document.
    """

    metadata_list: list[str] = []
    document_content: str | None = None
    info_source: InferenceChunk | LlmDoc | None = None
    info_content: str | None = None

    if isinstance(document, InferenceSection):
        info_source = document.center_chunk
        info_content = document.combined_content
    elif isinstance(document, LlmDoc):
        info_source = document
        info_content = document.content

    for key, value in info_source.metadata.items():
        metadata_list.append(f" - {key}: {value}")

    if metadata_list:
        metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
    else:
        metadata_str = ""

    # Construct document header with number and semantic identifier
    doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"

    # Combine all parts with proper spacing
    document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"

    return document_content


def get_near_empty_step_results(
    step_number: int,
    step_answer: str,
    verified_reranked_documents: list[InferenceSection] = [],
) -> SubQuestionAnswerResults:
    """
    Get near-empty step results from a list of step results.
    """
    return SubQuestionAnswerResults(
        question=STEP_DESCRIPTIONS[step_number].description,
        question_id="0_" + str(step_number),
        answer=step_answer,
        verified_high_quality=True,
        sub_query_retrieval_results=[],
        verified_reranked_documents=verified_reranked_documents,
        context_documents=[],
        cited_documents=[],
        sub_question_retrieval_stats=AgentChunkRetrievalStats(
            verified_count=None,
            verified_avg_scores=None,
            rejected_count=None,
            rejected_avg_scores=None,
            verified_doc_chunk_ids=[],
            dismissed_doc_chunk_ids=[],
        ),
    )
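To make the connectivity check above concrete, here is a small illustrative input; the entity and relationship IDs follow the source__relationship__target convention that split_relationship_id parses, and the values themselves are made up.

entities = ["ACCOUNT::acme", "EMPLOYEE::jane", "EMPLOYEE::bob"]
relationships = ["ACCOUNT::acme__has_owner__EMPLOYEE::jane"]

# "EMPLOYEE::bob" is unreachable from "ACCOUNT::acme", so the BFS does not visit
# every entity and _check_entities_disconnected(entities, relationships) returns True.
# Adding "ACCOUNT::acme__has_member__EMPLOYEE::bob" would connect the set and flip the result to False.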
backend/onyx/agents/agent_search/kb_search/models.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from pydantic import BaseModel

from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum


class KGQuestionEntityExtractionResult(BaseModel):
    entities: list[str]
    terms: list[str]
    time_filter: str | None


class KGAnswerApproach(BaseModel):
    search_type: KGSearchType
    search_strategy: KGAnswerStrategy
    format: KGAnswerFormat
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None


class KGQuestionRelationshipExtractionResult(BaseModel):
    relationships: list[str]


class KGQuestionExtractionResult(BaseModel):
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None


class KGExpandedGraphObjects(BaseModel):
    entities: list[str]
    relationships: list[str]


class KGSteps(BaseModel):
    description: str
    activities: list[str]


class KGEntityDocInfo(BaseModel):
    doc_id: str | None
    doc_semantic_id: str | None
    doc_link: str | None
    semantic_entity_name: str
    semantic_linked_entity_name: str
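A small sketch of how these models get used: the extraction node that follows parses the cleaned LLM output into KGQuestionEntityExtractionResult with pydantic's model_validate_json. The JSON payload below is illustrative only.

from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult

raw = '{"entities": ["ACCOUNT::acme"], "terms": ["renewal"], "time_filter": null}'
result = KGQuestionEntityExtractionResult.model_validate_json(raw)
assert result.entities == ["ACCOUNT::acme"] and result.time_filter is None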
@@ -0,0 +1,274 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
|
||||
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
|
||||
from onyx.agents.agent_search.kb_search.models import (
|
||||
KGQuestionRelationshipExtractionResult,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
|
||||
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.kg_temp_view import create_views
|
||||
from onyx.db.kg_temp_view import get_user_view_names
|
||||
from onyx.db.relationships import get_allowed_relationship_type_pairs
|
||||
from onyx.kg.extractions.extraction_processing import get_entity_types_str
|
||||
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
|
||||
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_ert(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ERTExtractionUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
# recheck KG enablement at outset KG graph
|
||||
|
||||
if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
|
||||
logger.error("KG approach is not enabled, the KG agent flow cannot run.")
|
||||
raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")
|
||||
|
||||
_KG_STEP_NR = 1
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
elif graph_config.tooling.search_tool.user is None:
|
||||
raise ValueError("User is not set")
|
||||
else:
|
||||
user_email = graph_config.tooling.search_tool.user.email
|
||||
user_name = user_email.split("@")[0] or "unknown"
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
today_date = datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_step_structure(writer)
|
||||
|
||||
# Now specify core activities in the step (step 1)
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Create temporary views. TODO: move into parallel step, if ultimately materialized
|
||||
allowed_docs_view_name, kg_relationships_view_name = get_user_view_names(user_email)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_views(
|
||||
db_session,
|
||||
user_email=user_email,
|
||||
allowed_docs_view_name=allowed_docs_view_name,
|
||||
kg_relationships_view_name=kg_relationships_view_name,
|
||||
)
|
||||
|
||||
### get the entities, terms, and filters
|
||||
|
||||
query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
|
||||
entity_types=all_entity_types,
|
||||
relationship_types=all_relationship_types,
|
||||
)
|
||||
|
||||
query_extraction_prompt = (
|
||||
query_extraction_pre_prompt.replace("---content---", question)
|
||||
.replace("---today_date---", today_date)
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=query_extraction_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Entity extraction LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_ENTITY_EXTRACTION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = (
|
||||
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = KGQuestionEntityExtractionResult(
|
||||
entities=[],
|
||||
terms=[],
|
||||
time_filter="",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
entity_extraction_result = KGQuestionEntityExtractionResult(
|
||||
entities=[],
|
||||
terms=[],
|
||||
time_filter="",
|
||||
)
|
||||
|
||||
# remove the attribute filters from the entities for the purpose of relationship extraction
|
||||
entities_no_attributes = [
|
||||
entity.split("--")[0] for entity in entity_extraction_result.entities
|
||||
]
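# (Clarifying example with hypothetical values: an extracted entity such as
# "ACCOUNT::acme--industry=software" is reduced to "ACCOUNT::acme" here, since
# splitting on "--" drops the attribute suffix.)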
|
||||
ert_entities_string = f"Entities: {entities_no_attributes}\n"
|
||||
|
||||
### get the relationships
|
||||
|
||||
# find the relationship types that match the extracted entity types
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
|
||||
db_session, entity_extraction_result.entities
|
||||
)
|
||||
|
||||
query_relationship_extraction_prompt = (
|
||||
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
|
||||
.replace("---today_date---", today_date)
|
||||
.replace(
|
||||
"---relationship_type_options---",
|
||||
" - " + "\n - ".join(allowed_relationship_pairs),
|
||||
)
|
||||
.replace("---identified_entities---", ert_entities_string)
|
||||
.replace("---entity_types---", all_entity_types)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=query_relationship_extraction_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Relationship extraction LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
|
||||
try:
|
||||
relationship_extraction_result = (
|
||||
KGQuestionRelationshipExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
    "Failed to parse LLM response as JSON in Relationship Extraction"
)
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
|
||||
## STEP 1
|
||||
# Stream answer pieces out to the UI for Step 1
|
||||
|
||||
extracted_entity_string = " \n ".join(
|
||||
[x.split("--")[0] for x in entity_extraction_result.entities]
|
||||
)
|
||||
extracted_relationship_string = " \n ".join(
|
||||
relationship_extraction_result.relationships
|
||||
)
|
||||
|
||||
step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
|
||||
# Finish Step 1
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ERTExtractionUpdate(
|
||||
entities_types_str=all_entity_types,
|
||||
relationship_types_str=all_relationship_types,
|
||||
extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
extracted_entities_no_attributes=entities_no_attributes,
|
||||
extracted_relationships=relationship_extraction_result.relationships,
|
||||
extracted_terms=entity_extraction_result.terms,
|
||||
time_filter=entity_extraction_result.time_filter,
|
||||
kg_doc_temp_view_name=allowed_docs_view_name,
|
||||
kg_rel_temp_view_name=kg_relationships_view_name,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="extract entities terms",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
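Note: the LLM-output cleanup above (strip the "```json" fences, cut the string down to the outermost braces, then model_validate_json with a defaulted fallback) is repeated in several of the nodes below. A minimal sketch of how it could be factored into a shared helper follows; the name _extract_json_object is an assumption for illustration and is not part of this change.

def _extract_json_object(raw_llm_content: str) -> str:
    """Best-effort extraction of a single JSON object from an LLM response."""
    # Drop common markdown fences and newlines first.
    cleaned = (
        raw_llm_content.replace("```json\n", "")
        .replace("\n```", "")
        .replace("\n", "")
    )
    # Keep only the outermost {...} span; find/rfind return -1 when missing.
    first_bracket = cleaned.find("{")
    last_bracket = cleaned.rfind("}")
    if first_bracket == -1 or last_bracket == -1:
        raise ValueError("No JSON object found in LLM response")
    return cleaned[first_bracket : last_bracket + 1]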
|
||||
311
backend/onyx/agents/agent_search/kb_search/nodes/a2_analyze.py
Normal file
@@ -0,0 +1,311 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import (
    create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
    stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def _articulate_normalizations(
|
||||
entity_normalization_map: dict[str, str],
|
||||
relationship_normalization_map: dict[str, str],
|
||||
) -> str:
|
||||
|
||||
remark_list: list[str] = []
|
||||
|
||||
if entity_normalization_map:
|
||||
remark_list.append("\n Entities:")
|
||||
for extracted_entity, normalized_entity in entity_normalization_map.items():
|
||||
remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
|
||||
|
||||
if relationship_normalization_map:
|
||||
remark_list.append(" \n Relationships:")
|
||||
for (
|
||||
extracted_relationship,
|
||||
normalized_relationship,
|
||||
) in relationship_normalization_map.items():
|
||||
remark_list.append(
|
||||
f" - {extracted_relationship} -> {normalized_relationship}"
|
||||
)
|
||||
|
||||
return " \n ".join(remark_list)
|
||||
|
||||
|
||||
def _get_fully_connected_entities(
|
||||
entities: list[str], relationships: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Analyze the connectedness of the entities and relationships.
|
||||
"""
|
||||
# Build a dictionary to track connections for each entity
|
||||
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
|
||||
|
||||
# Parse relationships to build connection graph
|
||||
for relationship in relationships:
|
||||
# Split relationship into parts. Test for proper formatting just in case.
|
||||
# Should never be an error though at this point.
|
||||
parts = split_relationship_id(relationship)
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid relationship: {relationship}")
|
||||
|
||||
entity1 = parts[0]
|
||||
entity2 = parts[2]
|
||||
|
||||
# Add bidirectional connections
|
||||
if entity1 in entity_connections:
|
||||
entity_connections[entity1].add(entity2)
|
||||
if entity2 in entity_connections:
|
||||
entity_connections[entity2].add(entity1)
|
||||
|
||||
# Find entities connected to all others
|
||||
fully_connected_entities = []
|
||||
all_entities = set(entities)
|
||||
|
||||
for entity, connections in entity_connections.items():
|
||||
# Check if this entity is connected to all other entities
|
||||
if connections == all_entities - {entity}:
|
||||
fully_connected_entities.append(entity)
|
||||
|
||||
return fully_connected_entities
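# (Clarifying example with hypothetical values, assuming split_relationship_id
# yields (source, relationship, target): for entities ["A", "B", "C"] and
# relationships ["A__works_with__B", "A__works_with__C", "B__works_with__C"],
# all three entities are returned; dropping "B__works_with__C" leaves only "A".)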
|
||||
|
||||
|
||||
def _check_for_single_doc(
|
||||
normalized_entities: list[str],
|
||||
raw_entities: list[str],
|
||||
normalized_relationship_strings: list[str],
|
||||
raw_relationships: list[str],
|
||||
normalized_time_filter: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
|
||||
None is returned if the query is not for a single document.
|
||||
"""
|
||||
if (
|
||||
len(normalized_entities) == 1
|
||||
and len(raw_entities) == 1
|
||||
and len(normalized_relationship_strings) == 0
|
||||
and len(raw_relationships) == 0
|
||||
and normalized_time_filter is None
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
single_doc_id = get_document_id_for_entity(
|
||||
db_session, normalized_entities[0]
|
||||
)
|
||||
else:
|
||||
single_doc_id = None
|
||||
return single_doc_id
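# (Clarifying example with a hypothetical entity: a query that normalizes to the
# single entity "JIRA::ENG-2243K" with no relationships and no time filter returns
# the backing document id here; any other combination returns None.)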
|
||||
|
||||
|
||||
def analyze(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> AnalysisUpdate:
|
||||
"""
|
||||
LangGraph node to normalize the extracted entities and relationships and determine the answer strategy.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 2
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
entities = (
|
||||
state.extracted_entities_no_attributes
|
||||
) # attribute knowledge is not required for this step
|
||||
relationships = state.extracted_relationships
|
||||
terms = state.extracted_terms
|
||||
time_filter = state.time_filter
|
||||
|
||||
## STEP 2 - stream out goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Continue with node
|
||||
|
||||
normalized_entities = normalize_entities(
|
||||
entities, allowed_docs_temp_view_name=state.kg_doc_temp_view_name
|
||||
)
|
||||
|
||||
query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
|
||||
state.extracted_entities_w_attributes,
|
||||
normalized_entities.entity_normalization_map,
|
||||
)
|
||||
|
||||
normalized_relationships = normalize_relationships(
|
||||
relationships, normalized_entities.entity_normalization_map
|
||||
)
|
||||
normalized_terms = normalize_terms(terms)
|
||||
normalized_time_filter = time_filter
|
||||
|
||||
# If single-doc inquiry, send to single-doc processing directly
|
||||
|
||||
single_doc_id = _check_for_single_doc(
|
||||
normalized_entities=normalized_entities.entities,
|
||||
raw_entities=entities,
|
||||
normalized_relationship_strings=normalized_relationships.relationships,
|
||||
raw_relationships=relationships,
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
)
|
||||
|
||||
# Expand the entities and relationships to make sure that entities are connected
|
||||
|
||||
graph_expansion = create_minimal_connected_query_graph(
|
||||
normalized_entities.entities,
|
||||
normalized_relationships.relationships,
|
||||
max_depth=2,
|
||||
)
|
||||
|
||||
query_graph_entities = graph_expansion.entities
|
||||
query_graph_relationships = graph_expansion.relationships
|
||||
|
||||
# Evaluate whether a search needs to be done after identifying all entities and relationships
|
||||
|
||||
strategy_generation_prompt = (
|
||||
STRATEGY_GENERATION_PROMPT.replace(
|
||||
"---entities---", "\n".join(query_graph_entities)
|
||||
)
|
||||
.replace("---relationships---", "\n".join(query_graph_relationships))
|
||||
.replace("---possible_entities---", state.entities_types_str)
|
||||
.replace("---possible_relationships---", state.relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=strategy_generation_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Strategy generation LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_STRATEGY_GENERATION_TIMEOUT,
|
||||
# fast_llm.invoke,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=5,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
approach_extraction_result = KGAnswerApproach.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
search_type = approach_extraction_result.search_type
|
||||
search_strategy = approach_extraction_result.search_strategy
|
||||
output_format = approach_extraction_result.format
|
||||
broken_down_question = approach_extraction_result.broken_down_question
|
||||
divide_and_conquer = approach_extraction_result.divide_and_conquer
|
||||
except ValueError:
|
||||
logger.error(
    "Failed to parse LLM response as JSON in Strategy Generation"
)
|
||||
search_type = KGSearchType.SEARCH
|
||||
search_strategy = KGAnswerStrategy.DEEP
|
||||
output_format = KGAnswerFormat.TEXT
|
||||
broken_down_question = None
|
||||
divide_and_conquer = YesNoEnum.NO
|
||||
if search_strategy is None or output_format is None:
|
||||
raise ValueError(f"Invalid strategy: {cleaned_response}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
raise e
|
||||
|
||||
# Stream out relevant results
|
||||
|
||||
if single_doc_id:
|
||||
search_strategy = (
|
||||
KGAnswerStrategy.DEEP
|
||||
) # if a single doc is identified, we will want to look at the details.
|
||||
|
||||
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
|
||||
Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# End node
|
||||
|
||||
return AnalysisUpdate(
|
||||
normalized_core_entities=normalized_entities.entities,
|
||||
normalized_core_relationships=normalized_relationships.relationships,
|
||||
entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
query_graph_entities_no_attributes=query_graph_entities,
|
||||
query_graph_entities_w_attributes=query_graph_entities_w_attributes,
|
||||
query_graph_relationships=query_graph_relationships,
|
||||
normalized_terms=normalized_terms.terms,
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
strategy=search_strategy,
|
||||
broken_down_question=broken_down_question,
|
||||
output_format=output_format,
|
||||
divide_and_conquer=divide_and_conquer,
|
||||
single_doc_id=single_doc_id,
|
||||
search_type=search_type,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="analyze",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
remarks=[
|
||||
_articulate_normalizations(
|
||||
entity_normalization_map=normalized_entities.entity_normalization_map,
|
||||
relationship_normalization_map=normalized_relationships.relationship_normalization_map,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,415 @@
from datetime import datetime
from typing import Any
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
    stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from shared_configs.contextvars import get_current_tenant_id
from onyx.configs.kg_configs import KG_MAX_DEEP_SEARCH_RESULTS
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout


logger = setup_logger()


def _drop_temp_views(
|
||||
allowed_docs_view_name: str, kg_relationships_view_name: str
|
||||
) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
drop_views(
|
||||
db_session,
|
||||
allowed_docs_view_name=allowed_docs_view_name,
|
||||
kg_relationships_view_name=kg_relationships_view_name,
|
||||
)
|
||||
|
||||
|
||||
def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> str:
|
||||
"""
|
||||
Build a string of contextualized entities so the model knows what a normalized
entity such as e.g. ACCOUNT::SF_8254Hs refers to
|
||||
"""
|
||||
entity_explanation_components = []
|
||||
for entity, normalized_entity in entity_normalization_map.items():
|
||||
entity_explanation_components.append(f" - {entity} -> {normalized_entity}")
|
||||
return "\n".join(entity_explanation_components)
|
||||
|
||||
|
||||
def _sql_is_aggregate_query(sql_statement: str) -> bool:
|
||||
return any(
|
||||
agg_func in sql_statement.upper()
|
||||
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
|
||||
)
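# (Clarifying example: _sql_is_aggregate_query("SELECT COUNT(*) FROM kg_relationship;")
# is True, while _sql_is_aggregate_query("SELECT name FROM kg_relationship;") is False.
# Note that this is a plain substring check on the upper-cased statement.)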
|
||||
|
||||
|
||||
def _get_source_documents(
|
||||
sql_statement: str,
|
||||
llm: LLM,
|
||||
allowed_docs_view_name: str,
|
||||
kg_relationships_view_name: str,
|
||||
) -> str | None:
|
||||
"""
|
||||
Generate SQL to retrieve source documents based on the input sql statement.
|
||||
"""
|
||||
|
||||
source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
|
||||
"---original_sql_statement---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=source_detection_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
cleaned_response: str | None = None
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1200,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
sql_statement = cleaned_response.split("<sql>")[1].strip()
|
||||
sql_statement = sql_statement.split("</sql>")[0].strip()
|
||||
|
||||
except Exception as e:
|
||||
if cleaned_response is not None:
|
||||
logger.error(
|
||||
f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Could not generate source documents SQL: {e}")
|
||||
|
||||
return None
|
||||
|
||||
return sql_statement
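The <sql>...</sql> (and, in generate_simple_sql below, <reasoning>...</reasoning>) blocks are pulled out of the model response with chained str.split calls. A minimal sketch of a shared helper for that pattern is shown here for illustration; _extract_tagged_block is an assumed name, not part of this change.

def _extract_tagged_block(text: str, tag: str) -> str | None:
    """Return the content between <tag> and </tag>, or None if either tag is missing."""
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    if open_tag not in text or close_tag not in text:
        return None
    return text.split(open_tag)[1].split(close_tag)[0].strip()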
|
||||
|
||||
|
||||
def generate_simple_sql(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> SQLSimpleGenerationUpdate:
|
||||
"""
|
||||
LangGraph node to generate and execute the SQL for the knowledge graph query.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 3
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
entities_types_str = state.entities_types_str
|
||||
relationship_types_str = state.relationship_types_str
|
||||
|
||||
single_doc_id = state.single_doc_id
|
||||
|
||||
if state.kg_doc_temp_view_name is None:
|
||||
raise ValueError("kg_doc_temp_view_name is not set")
|
||||
if state.kg_rel_temp_view_name is None:
|
||||
raise ValueError("kg_rel_temp_view_name is not set")
|
||||
|
||||
## STEP 3 - articulate goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
elif graph_config.tooling.search_tool.user is None:
|
||||
raise ValueError("User is not set")
|
||||
else:
|
||||
user_email = graph_config.tooling.search_tool.user.email
|
||||
user_name = user_email.split("@")[0]
|
||||
|
||||
if state.search_type == KGSearchType.SQL and single_doc_id:
|
||||
|
||||
# If single doc id already identified, we do not need to go through the KG
|
||||
# query cycle, saving a lot of time.
|
||||
|
||||
main_sql_statement: str | None = None
|
||||
query_results: list[dict[str, Any]] | None = None
|
||||
source_documents_sql: str | None = None
|
||||
source_document_results: list[str] | None = [single_doc_id]
|
||||
reasoning: str | None = (
|
||||
f"A KG query was not required as the source document was already identified: {single_doc_id}"
|
||||
)
|
||||
|
||||
step_answer = f"Source document already identified: {single_doc_id}"
|
||||
|
||||
elif state.search_type == KGSearchType.SEARCH:
|
||||
# If we do a filtered search, then we do not need to go through the SQL
|
||||
# generation process.
|
||||
|
||||
main_sql_statement = None
|
||||
query_results = None
|
||||
source_documents_sql = None
|
||||
source_document_results = None
|
||||
reasoning = "A KG query was not required as we will use a filtered search."
|
||||
|
||||
step_answer = "Filtered search will be used."
|
||||
|
||||
else:
|
||||
# If no single doc id already identified, we need to go through the KG
|
||||
# query cycle, including generating the SQL for the answer and the sources
|
||||
|
||||
# Build prompt
|
||||
|
||||
# First, create string of contextualized entities to avoid the model not
|
||||
# being aware of what eg ACCOUNT::SF_8254Hs means as a normalized entity
|
||||
|
||||
entity_explanation_str = _build_entity_explanation_str(
|
||||
state.entity_normalization_map
|
||||
)
|
||||
|
||||
current_tenant = get_current_tenant_id()
|
||||
|
||||
current_tenant_view_name = f'"{current_tenant}".{state.kg_doc_temp_view_name}'
|
||||
current_tenant_rel_view_name = f'"{current_tenant}".{state.kg_rel_temp_view_name}'
|
||||
|
||||
simple_sql_prompt = (
|
||||
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
|
||||
.replace("---relationship_types---", relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
.replace("---entity_explanation_string---", entity_explanation_str)
|
||||
.replace(
|
||||
"---query_entities_with_attributes---",
|
||||
"\n".join(state.query_graph_entities_w_attributes),
|
||||
)
|
||||
.replace(
|
||||
"---query_relationships---", "\n".join(state.query_graph_relationships)
|
||||
)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
# prepare SQL query generation
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=simple_sql_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# SQL generation LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
sql_statement = sql_statement.replace(
|
||||
"kg_relationship", current_tenant_rel_view_name
|
||||
)
|
||||
|
||||
reasoning = (
|
||||
cleaned_response.split("<reasoning>")[1]
|
||||
.strip()
|
||||
.split("</reasoning>")[0]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
|
||||
_drop_temp_views(
|
||||
allowed_docs_view_name=current_tenant_view_name,
|
||||
kg_relationships_view_name=current_tenant_rel_view_name,
|
||||
)
|
||||
raise e
|
||||
|
||||
logger.debug(f"A3 - sql_statement: {sql_statement}")
|
||||
|
||||
# Correction if needed:
|
||||
|
||||
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
|
||||
"---draft_sql---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=correction_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_SQL_GENERATION_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
|
||||
_drop_temp_views(
|
||||
allowed_docs_view_name=current_tenant_view_name,
|
||||
kg_relationships_view_name=current_tenant_rel_view_name,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
|
||||
|
||||
# Get SQL for source documents
|
||||
|
||||
source_documents_sql = _get_source_documents(
|
||||
sql_statement,
|
||||
llm=primary_llm,
|
||||
allowed_docs_view_name=current_tenant_view_name,
|
||||
kg_relationships_view_name=current_tenant_rel_view_name,
|
||||
)
|
||||
|
||||
logger.info(f"A3 source_documents_sql: {source_documents_sql}")
|
||||
|
||||
scalar_result = None
|
||||
query_results = None
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
query_results = (
|
||||
[{"count": int(scalar_result)}]
|
||||
if scalar_result is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
query_results = [dict(row._mapping) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing SQL query: {e}")
|
||||
|
||||
raise e
|
||||
|
||||
source_document_results = None
|
||||
if source_documents_sql is not None and source_documents_sql != sql_statement:
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(source_documents_sql))
|
||||
rows = result.fetchall()
|
||||
query_source_document_results = [dict(row._mapping) for row in rows]
|
||||
source_document_results = [
|
||||
source_document_result["source_document"]
|
||||
for source_document_result in query_source_document_results
|
||||
]
|
||||
except Exception as e:
|
||||
# No stopping here, the individualized SQL query is not mandatory
|
||||
logger.error(f"Error executing Individualized SQL query: {e}")
|
||||
# individualized_query_results = None
|
||||
|
||||
else:
|
||||
source_document_results = None
|
||||
|
||||
_drop_temp_views(
|
||||
allowed_docs_view_name=state.kg_doc_temp_view_name,
|
||||
kg_relationships_view_name=state.kg_rel_temp_view_name,
|
||||
)
|
||||
|
||||
logger.info(f"A3 - Number of query_results: {len(query_results) if query_results else 0}")
|
||||
|
||||
# Stream out reasoning and SQL query
|
||||
|
||||
step_answer = f"Reasoning: {reasoning} \n \n SQL Query: {sql_statement}"
|
||||
|
||||
main_sql_statement = sql_statement
|
||||
|
||||
if reasoning:
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
|
||||
|
||||
if main_sql_statement:
|
||||
stream_write_step_answer_explicit(
|
||||
writer,
|
||||
step_nr=_KG_STEP_NR,
|
||||
answer=f" \n Generated SQL: {main_sql_statement}",
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# Update path if too many results are retrieved
|
||||
|
||||
if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
|
||||
updated_strategy = KGAnswerStrategy.SIMPLE
|
||||
else:
|
||||
updated_strategy = None
|
||||
|
||||
return SQLSimpleGenerationUpdate(
|
||||
sql_query=main_sql_statement,
|
||||
sql_query_results=query_results,
|
||||
individualized_sql_query=None,
|
||||
individualized_query_results=None,
|
||||
source_documents_sql=source_documents_sql,
|
||||
source_document_results=source_document_results or [],
|
||||
updated_strategy=updated_strategy,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate simple sql",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,186 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout


logger = setup_logger()


def construct_deep_search_filters(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter
|
||||
) -> DeepSearchFilterUpdate:
|
||||
"""
|
||||
LangGraph node to construct the filters for the deep (divide-and-conquer) search.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities_no_attributes
|
||||
relationships = state.query_graph_relationships
|
||||
simple_sql_query = state.sql_query
|
||||
simple_sql_results = state.sql_query_results
|
||||
source_document_results = state.source_document_results
|
||||
if simple_sql_results:
|
||||
simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
|
||||
else:
|
||||
simple_sql_results_str = "(no SQL results generated)"
|
||||
if source_document_results:
|
||||
source_document_results_str = "\n".join(
|
||||
[str(x) for x in source_document_results]
|
||||
)
|
||||
else:
|
||||
source_document_results_str = "(no source document results generated)"
|
||||
|
||||
logger.info(
    f"B1 - characters in source_document_results_str: {len(source_document_results_str)}"
)
|
||||
|
||||
search_filter_construction_prompt = (
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
|
||||
"---entity_type_descriptions---",
|
||||
entities_types_str,
|
||||
)
|
||||
.replace(
|
||||
"---entity_filters---",
|
||||
"\n".join(entities),
|
||||
)
|
||||
.replace(
|
||||
"---relationship_filters---",
|
||||
"\n".join(relationships),
|
||||
)
|
||||
.replace(
|
||||
"---sql_query---",
|
||||
simple_sql_query or "(no SQL generated)",
|
||||
)
|
||||
.replace(
|
||||
"---sql_results---",
|
||||
simple_sql_results_str or "(no SQL results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---source_document_results---",
|
||||
source_document_results_str or "(no source document results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---question---",
|
||||
question,
|
||||
)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=search_filter_construction_prompt,
|
||||
)
|
||||
]
|
||||
llm = graph_config.tooling.primary_llm
|
||||
# Filter construction LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_FILTER_CONSTRUCTION_TIMEOUT,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=1400,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
|
||||
if last_bracket == -1 or first_bracket == -1:
|
||||
raise ValueError("No valid JSON found in LLM response - no brackets found")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
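# (Note added for clarity: the "{{" -> '{"' and "}}" -> '"}' replacements assume the
# model occasionally echoes the doubled braces used for escaping in the prompt
# template; this is a heuristic repair of that specific case, not a general JSON fix.)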
|
||||
|
||||
try:
|
||||
|
||||
filter_results = KGFilterConstructionResults.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
    "Failed to parse LLM response as JSON in Filter Construction"
)
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
|
||||
logger.info(f"B1 - filter_results: {filter_results}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in construct_deep_search_filters: {e}")
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
|
||||
div_con_structure = filter_results.structure
|
||||
|
||||
logger.info(f"div_con_structure: {div_con_structure}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
|
||||
db_session
|
||||
)
|
||||
|
||||
source_division = False
|
||||
|
||||
if div_con_structure:
|
||||
for entity_type in double_grounded_entity_types:
|
||||
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
|
||||
source_division = True
|
||||
break
|
||||
|
||||
return DeepSearchFilterUpdate(
|
||||
vespa_filter_results=filter_results,
|
||||
div_con_entities=div_con_structure,
|
||||
source_division=source_division,
|
||||
global_entity_filters=[
|
||||
make_entity_id(global_filter, "*")
|
||||
for global_filter in filter_results.global_entity_filters
|
||||
],
|
||||
global_relationship_filters=filter_results.global_relationship_filters,
|
||||
local_entity_filters=filter_results.local_entity_filters,
|
||||
source_filters=filter_results.source_document_filters,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="construct deep search filters",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
@@ -0,0 +1,178 @@
import copy
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import (
    get_doc_information_for_entity,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.chat.models import SubQueryPiece
from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()


def process_individual_deep_search(
|
||||
state: ResearchObjectInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ResearchObjectUpdate:
|
||||
"""
|
||||
LangGraph node to research one individual object as part of the deep search.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = state.broken_down_question
|
||||
source_division = state.source_division
|
||||
source_filters = state.source_entity_filters
|
||||
|
||||
object = state.entity.replace("::", ":: ").lower()
|
||||
|
||||
object_id = object.split("::")[1].strip()
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
|
||||
research_nr = state.research_nr
|
||||
|
||||
extended_question = f"{question} in regards to {object}"
|
||||
if source_division:
|
||||
extended_question = question
|
||||
|
||||
raw_kg_entity_filters = copy.deepcopy(
|
||||
list(set((state.vespa_filter_results.global_entity_filters + [state.entity])))
|
||||
)
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if "::" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += "::*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = copy.deepcopy(
|
||||
state.vespa_filter_results.global_relationship_filters
|
||||
)
|
||||
|
||||
# if this is a single-object analysis and object is a source,
|
||||
# drop the other filters as they already were included
|
||||
# in the KG query that led to this analysis
|
||||
|
||||
if source_division and source_filters:
|
||||
for source_filter in source_filters:
|
||||
if object_id.lower() == source_filter.lower():
|
||||
source_filters = [source_filter]
|
||||
kg_relationship_filters = []
|
||||
kg_entity_filters = []
|
||||
break
|
||||
|
||||
logger.info("Research for object: " + object)
|
||||
logger.info(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = research(
|
||||
question=extended_question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=source_filters,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Build prompt
|
||||
|
||||
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
question=extended_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Object research LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
object_research_results = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
|
||||
return ResearchObjectUpdate(
|
||||
research_object_results=[
|
||||
{
|
||||
"object": object.replace("::", ":: ").capitalize(),
|
||||
"results": object_research_results,
|
||||
}
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="process individual deep search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
@@ -0,0 +1,177 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
    stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
    get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.chat.models import SubQueryPiece
from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout


logger = setup_logger()


def filtered_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to do a filtered search.
|
||||
"""
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
|
||||
if not state.vespa_filter_results:
|
||||
raise ValueError("vespa_filter_results is not provided")
|
||||
raw_kg_entity_filters = list(
|
||||
set((state.vespa_filter_results.global_entity_filters))
|
||||
)
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if "::" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += "::*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
|
||||
|
||||
logger.info("Starting filtered search")
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=None,
|
||||
search_tool=search_tool,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(
|
||||
answer_generation_documents.context_documents
|
||||
):
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Build prompt
|
||||
|
||||
|
||||
|
||||
kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
|
||||
question=question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Filtered-search answer LLM call
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
KG_FILTERED_SEARCH_TIMEOUT,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
filtered_search_answer = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in filtered_search: {e}")
|
||||
|
||||
step_answer = "Filtered search is complete."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=filtered_search_answer,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="filtered search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=retrieved_docs,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,64 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
    stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger

logger = setup_logger()


def consolidate_individual_deep_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to consolidate the per-object results of the individual deep searches.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
node_start_time = datetime.now()
|
||||
|
||||
research_object_results = state.research_object_results
|
||||
|
||||
consolidated_research_object_results_str = "\n".join(
|
||||
[f"{x['object']}: {x['results']}" for x in research_object_results]
|
||||
)
|
||||
|
||||
consolidated_research_object_results_str = rename_entities_in_answer(
|
||||
consolidated_research_object_results_str
|
||||
)
|
||||
|
||||
step_answer = "All research is complete. Consolidating results..."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=consolidated_research_object_results_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="consolidate individual deep search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,122 @@
from datetime import datetime

from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter

from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
    stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
    get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger


logger = setup_logger()


def _get_formated_source_reference_results(
|
||||
source_document_results: list[str] | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Generate formatted reference results from the source document ids.
|
||||
"""
|
||||
|
||||
if source_document_results is None:
|
||||
return None
|
||||
|
||||
# get all entities that correspond to an Onyx document
|
||||
document_ids = source_document_results
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
llm_doc_information_results = get_base_llm_doc_information(
|
||||
session, document_ids
|
||||
)
|
||||
|
||||
if len(llm_doc_information_results) == 0:
|
||||
return ""
|
||||
|
||||
return (
|
||||
f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
|
||||
+ " \n \n ".join(llm_doc_information_results)
|
||||
)
|
||||
|
||||
|
||||
def process_kg_only_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ResultsDataUpdate:
|
||||
"""
|
||||
LangGraph node to format the knowledge-graph query results when no further research is needed.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
query_results = state.sql_query_results
|
||||
source_document_results = state.source_document_results
|
||||
|
||||
# we use this stream write explicitly
|
||||
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
query_results_list = []
|
||||
|
||||
if query_results:
|
||||
for query_result in query_results:
|
||||
query_results_list.append(
|
||||
str(query_result).replace("::", ":: ").capitalize()
|
||||
)
|
||||
else:
|
||||
raise ValueError("No query results were found")
|
||||
|
||||
query_results_data_str = "\n".join(query_results_list)
|
||||
|
||||
source_reference_result_str = _get_formatted_source_reference_results(
|
||||
source_document_results
|
||||
)
|
||||
|
||||
## STEP 4 - same components as Step 1
|
||||
|
||||
step_answer = (
|
||||
"No further research is needed, the answer is derived from the knowledge graph."
|
||||
)
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
individualized_query_results_data_str="",
|
||||
reference_results_str=source_reference_result_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="kg query results data processing",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,282 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_close_main_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_main_answer_token,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.kg_configs import KG_ANSWER_GENERATION_TIMEOUT
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
|
||||
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _stream_augmentations(
|
||||
llm_tokenizer: BaseTokenizer, streaming_text: str, writer: StreamWriter
|
||||
) -> None:
|
||||
|
||||
# Tokenize and stream the reference results
|
||||
tokens = llm_tokenizer.tokenize(streaming_text)
|
||||
for token in tokens:
|
||||
|
||||
stream_write_main_answer_token(writer, token)
|
||||
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
"""
|
||||
LangGraph node to generate the final knowledge graph answer.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
|
||||
# Close out previous streams of steps
|
||||
|
||||
# DECLARE STEPS DONE
|
||||
|
||||
stream_write_close_steps(writer)
|
||||
|
||||
## MAIN ANSWER
|
||||
|
||||
# identify whether documents have already been retrieved
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
for step_result in state.step_results:
|
||||
retrieved_docs += step_result.verified_reranked_documents
|
||||
|
||||
# if still needed, get a search done and send the results to the UI
|
||||
|
||||
if not retrieved_docs and state.source_document_results:
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=[],
|
||||
kg_relationships=[],
|
||||
kg_sources=state.source_document_results[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
],
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
kg_chunk_id_zero_only=True,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=SearchQueryInfo(
|
||||
predicted_search=SearchType.KEYWORD,
|
||||
# acl here is empty, because the search already happened and
|
||||
# we are streaming out the results.
|
||||
final_filters=IndexFilters(access_control_list=[]),
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# continue with the answer generation
|
||||
|
||||
output_format = (
|
||||
state.output_format.value
|
||||
if state.output_format
|
||||
else "<you be the judge how to best present the data>"
|
||||
)
|
||||
|
||||
# if deep path was taken:
|
||||
|
||||
consolidated_research_object_results_str = (
|
||||
state.consolidated_research_object_results_str
|
||||
)
|
||||
# reference_results_str = (
|
||||
# state.reference_results_str
|
||||
# ) # will not be part of LLM. Manually added to the answer
|
||||
|
||||
# if simple path was taken:
|
||||
introductory_answer = state.query_results_data_str # from simple answer path only
|
||||
if consolidated_research_object_results_str:
|
||||
research_results = consolidated_research_object_results_str
|
||||
else:
|
||||
research_results = ""
|
||||
|
||||
if introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(introductory_answer),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(consolidated_research_object_results_str),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and not consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
elif consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No research results or introductory answer provided")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=output_format_prompt,
|
||||
)
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
|
||||
dispatch_timings: list[float] = []
|
||||
response: list[str] = []
|
||||
|
||||
def stream_answer() -> list[str]:
|
||||
# Get the LLM's tokenizer
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=fast_llm.config.model_name,
|
||||
provider_type=fast_llm.config.model_provider,
|
||||
)
|
||||
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=1000,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
# Tokenize the content using the LLM's tokenizer
|
||||
tokens = llm_tokenizer.tokenize(content)
|
||||
for token in tokens:
|
||||
start_stream_token = datetime.now()
|
||||
stream_write_main_answer_token(
|
||||
writer, token, level=0, level_question_num=0
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(token)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
KG_ANSWER_GENERATION_TIMEOUT,
|
||||
stream_answer,
|
||||
)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=fast_llm.config.model_name,
|
||||
# provider_type=fast_llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# TODO: the fake streaming should happen in the front-end. Revisit and then
|
||||
# simply stream out the full text here in one go.
|
||||
# if reference_results_str:
|
||||
# # Get the LLM's tokenizer
|
||||
|
||||
# _stream_augmentations(llm_tokenizer, reference_results_str, writer)
|
||||
|
||||
# if state.remarks:
|
||||
# _stream_augmentations(llm_tokenizer, "Comments: \n " + "\n".join(state.remarks), writer)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
stream_write_close_main_answer(writer)
|
||||
|
||||
return MainOutput(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="query completed",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def log_data(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
"""
|
||||
LangGraph node to log the sub-question results of the knowledge graph search.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
|
||||
# commit original db_session
|
||||
|
||||
query_db_session = graph_config.persistence.db_session
|
||||
query_db_session.commit()
|
||||
|
||||
chat_session_id = graph_config.persistence.chat_session_id
|
||||
primary_message_id = graph_config.persistence.message_id
|
||||
sub_question_answer_results = state.step_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=query_db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
return MainOutput(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="query completed",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
65
backend/onyx/agents/agent_search/kb_search/ops.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def research(
|
||||
question: str,
|
||||
search_tool: SearchTool,
|
||||
document_sources: list[DocumentSource] | None = None,
|
||||
time_cutoff: datetime | None = None,
|
||||
kg_entities: list[str] | None = None,
|
||||
kg_relationships: list[str] | None = None,
|
||||
kg_terms: list[str] | None = None,
|
||||
kg_sources: list[str] | None = None,
|
||||
kg_chunk_id_zero_only: bool = False,
|
||||
inference_sections_only: bool = False,
|
||||
) -> list[LlmDoc] | list[InferenceSection]:
|
||||
# new db session to avoid concurrency issues
|
||||
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
retrieved_docs: list[LlmDoc] | list[InferenceSection] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
document_sources=document_sources,
|
||||
time_cutoff=time_cutoff,
|
||||
kg_entities=kg_entities,
|
||||
kg_relationships=kg_relationships,
|
||||
kg_terms=kg_terms,
|
||||
kg_sources=kg_sources,
|
||||
kg_chunk_id_zero_only=kg_chunk_id_zero_only,
|
||||
),
|
||||
):
|
||||
if (
|
||||
inference_sections_only
|
||||
and tool_response.id == "search_response_summary"
|
||||
):
|
||||
retrieved_docs = tool_response.response.top_sections[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
]
|
||||
retrieved_docs = cast(list[InferenceSection], retrieved_docs)
|
||||
break
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
retrieved_docs = cast(list[LlmDoc], tool_response.response)[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
]
|
||||
break
|
||||
return retrieved_docs
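For orientation, a minimal usage sketch of how a graph node might call research(); the search_tool instance and document IDs below are assumptions for illustration, not part of the diff:

# Hypothetical call: fetch only the chunk-0 sections for a set of KG source documents.
sections = research(
    question="What did the customer ask about?",
    search_tool=search_tool,  # assumed SearchTool instance
    kg_sources=["doc-id-1", "doc-id-2"],  # assumed document IDs
    kg_chunk_id_zero_only=True,
    inference_sections_only=True,
)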
|
||||
169
backend/onyx/agents/agent_search/kb_search/states.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from enum import Enum
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class StepResults(BaseModel):
|
||||
question: str
|
||||
question_id: str
|
||||
answer: str
|
||||
sub_query_retrieval_results: list[QueryRetrievalResult]
|
||||
verified_reranked_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
cited_documents: list[InferenceSection]
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
step_results: Annotated[list[SubQuestionAnswerResults], add]
|
||||
remarks: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class KGFilterConstructionResults(BaseModel):
|
||||
global_entity_filters: list[str]
|
||||
global_relationship_filters: list[str]
|
||||
local_entity_filters: list[list[str]]
|
||||
source_document_filters: list[str]
|
||||
structure: list[str]
|
||||
|
||||
|
||||
class KGSearchType(Enum):
|
||||
SEARCH = "SEARCH"
|
||||
SQL = "SQL"
|
||||
|
||||
|
||||
class KGAnswerStrategy(Enum):
|
||||
DEEP = "DEEP"
|
||||
SIMPLE = "SIMPLE"
|
||||
|
||||
|
||||
class KGAnswerFormat(Enum):
|
||||
LIST = "LIST"
|
||||
TEXT = "TEXT"
|
||||
|
||||
|
||||
class YesNoEnum(str, Enum):
|
||||
YES = "yes"
|
||||
NO = "no"
|
||||
|
||||
|
||||
class AnalysisUpdate(LoggerUpdate):
|
||||
normalized_core_entities: list[str] = []
|
||||
normalized_core_relationships: list[str] = []
|
||||
entity_normalization_map: dict[str, str] = {}
|
||||
relationship_normalization_map: dict[str, str] = {}
|
||||
query_graph_entities_no_attributes: list[str] = []
|
||||
query_graph_entities_w_attributes: list[str] = []
|
||||
query_graph_relationships: list[str] = []
|
||||
normalized_terms: list[str] = []
|
||||
normalized_time_filter: str | None = None
|
||||
strategy: KGAnswerStrategy | None = None
|
||||
output_format: KGAnswerFormat | None = None
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
single_doc_id: str | None = None
|
||||
search_type: KGSearchType | None = None
|
||||
|
||||
|
||||
class SQLSimpleGenerationUpdate(LoggerUpdate):
|
||||
sql_query: str | None = None
|
||||
sql_query_results: list[Dict[Any, Any]] | None = None
|
||||
individualized_sql_query: str | None = None
|
||||
individualized_query_results: list[Dict[Any, Any]] | None = None
|
||||
source_documents_sql: str | None = None
|
||||
source_document_results: list[str] | None = None
|
||||
updated_strategy: KGAnswerStrategy | None = None
|
||||
|
||||
|
||||
class ConsolidatedResearchUpdate(LoggerUpdate):
|
||||
consolidated_research_object_results_str: str | None = None
|
||||
|
||||
|
||||
class DeepSearchFilterUpdate(LoggerUpdate):
|
||||
vespa_filter_results: KGFilterConstructionResults | None = None
|
||||
div_con_entities: list[str] | None = None
|
||||
source_division: bool | None = None
|
||||
global_entity_filters: list[str] | None = None
|
||||
global_relationship_filters: list[str] | None = None
|
||||
local_entity_filters: list[list[str]] | None = None
|
||||
source_filters: list[str] | None = None
|
||||
|
||||
|
||||
class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class ERTExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
relationship_types_str: str = ""
|
||||
extracted_entities_w_attributes: list[str] = []
|
||||
extracted_entities_no_attributes: list[str] = []
|
||||
extracted_relationships: list[str] = []
|
||||
extracted_terms: list[str] = []
|
||||
time_filter: str | None = None
|
||||
kg_doc_temp_view_name: str | None = None
|
||||
kg_rel_temp_view_name: str | None = None
|
||||
|
||||
|
||||
class ResultsDataUpdate(LoggerUpdate):
|
||||
query_results_data_str: str | None = None
|
||||
individualized_query_results_data_str: str | None = None
|
||||
reference_results_str: str | None = None
|
||||
|
||||
|
||||
class ResearchObjectUpdate(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
ERTExtractionUpdate,
|
||||
AnalysisUpdate,
|
||||
SQLSimpleGenerationUpdate,
|
||||
ResultsDataUpdate,
|
||||
ResearchObjectOutput,
|
||||
DeepSearchFilterUpdate,
|
||||
ResearchObjectUpdate,
|
||||
ConsolidatedResearchUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
|
||||
|
||||
class ResearchObjectInput(LoggerUpdate):
|
||||
research_nr: int
|
||||
entity: str
|
||||
broken_down_question: str
|
||||
vespa_filter_results: KGFilterConstructionResults
|
||||
source_division: bool | None
|
||||
source_entity_filters: list[str] | None
|
||||
@@ -0,0 +1,29 @@
|
||||
from onyx.agents.agent_search.kb_search.models import KGSteps
|
||||
|
||||
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(
|
||||
description="Analyzing the question...",
|
||||
activities=[
|
||||
"Entities in Query",
|
||||
"Relationships in Query",
|
||||
"Terms in Query",
|
||||
"Time Filters",
|
||||
],
|
||||
),
|
||||
2: KGSteps(
|
||||
description="Planning the response approach...",
|
||||
activities=["Query Execution Strategy", "Answer Format"],
|
||||
),
|
||||
3: KGSteps(
|
||||
description="Querying the Knowledge Graph...",
|
||||
activities=[
|
||||
"Knowledge Graph Query",
|
||||
"Knowledge Graph Query Results",
|
||||
"Query for Source Documents",
|
||||
"Source Documents",
|
||||
],
|
||||
),
|
||||
4: KGSteps(
|
||||
description="Conducting further research on source documents...", activities=[]
|
||||
),
|
||||
}
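As a hedged illustration of how these step descriptions might be consumed when streaming progress (the stream_step helper is assumed, not part of the diff, and KGSteps is assumed to expose description and activities as attributes):

# Hypothetical helper: look up the description and activities for a step number.
def stream_step(step_nr: int) -> None:
    step = STEP_DESCRIPTIONS[step_nr]
    print(step.description)  # e.g. "Querying the Knowledge Graph..."
    for activity in step.activities:
        print(f" - {activity}")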
|
||||
@@ -8,6 +8,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.models import Persona
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.kg.models import KGConfigSettings
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
@@ -70,6 +71,7 @@ class GraphSearchConfig(BaseModel):
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
kg_config_settings: KGConfigSettings = KGConfigSettings()
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -88,7 +90,7 @@ def _parse_agent_event(
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput | DCMainInput,
|
||||
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -102,7 +104,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput | DCMainInput,
|
||||
input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
) -> AnswerStream:
|
||||
|
||||
for event in manage_sync_streaming(
|
||||
@@ -147,6 +149,21 @@ def run_basic_graph(
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_kb_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = kb_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = KBMainInput(log_messages=[])
|
||||
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.inputs.prompt_builder.raw_user_query},
|
||||
)
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dc_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
|
||||
@@ -159,3 +159,8 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
class QueryExpansionType(Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
class ReferenceResults(BaseModel):
|
||||
citations: list[str]
|
||||
general_entities: list[str]
|
||||
|
||||
105
backend/onyx/background/celery/apps/kg_processing.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
|
||||
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
)
|
||||
@@ -298,5 +298,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
)
|
||||
|
||||
21
backend/onyx/background/celery/configs/kg_processing.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -28,6 +28,24 @@ beat_task_templates: list[dict] = []
|
||||
|
||||
beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing-clustering-only",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
|
||||
"schedule": timedelta(seconds=600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
|
||||
321
backend/onyx/background/celery/tasks/kg_processing/tasks.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import time
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.document import check_for_documents_needing_kg_clustering
|
||||
from onyx.db.document import check_for_documents_needing_kg_processing
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import get_kg_processing_in_progress_status
|
||||
from onyx.db.kg_config import KGProcessingType
|
||||
from onyx.db.kg_config import set_kg_processing_in_progress_status
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.kg.clustering.clustering import kg_clustering
|
||||
from onyx.kg.extractions.extraction_processing import kg_extraction
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_kg_processing(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occasionally does some validation of existing state to clear up error conditions."""
|
||||
|
||||
time_start = time.monotonic()
|
||||
task_logger.warning("check_for_kg_processing - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
get_redis_replica_client()
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.KG_PROCESSING_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
kg_config = get_kg_config_settings(db_session)
|
||||
|
||||
if not kg_config.KG_ENABLED:
|
||||
|
||||
return None
|
||||
|
||||
kg_coverage_start = kg_config.KG_COVERAGE_START
|
||||
kg_max_coverage_days = kg_config.KG_MAX_COVERAGE_DAYS
|
||||
|
||||
kg_extraction_in_progress = kg_config.KG_EXTRACTION_IN_PROGRESS
|
||||
kg_clustering_in_progress = kg_config.KG_CLUSTERING_IN_PROGRESS
|
||||
|
||||
if kg_extraction_in_progress or kg_clustering_in_progress:
|
||||
task_logger.info(
|
||||
f"KG processing already in progress for tenant {tenant_id}, skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents_needing_kg_processing = check_for_documents_needing_kg_processing(
|
||||
db_session, kg_coverage_start, kg_max_coverage_days
|
||||
)
|
||||
|
||||
if not documents_needing_kg_processing:
|
||||
task_logger.info(
|
||||
f"No documents needing KG processing for tenant {tenant_id}, skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Found documents needing KG processing for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.KG_PROCESSING,
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.KG_PROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during kg processing check")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_kg_processing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
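The non-blocking lock acquisition above keeps overlapping beat runs from stepping on each other; a minimal sketch of that pattern follows (only the lock name and timeout come from the code, the body is illustrative):

# Sketch of the beat-task locking pattern (illustrative only).
lock = redis_client.lock(
    OnyxRedisLocks.KG_PROCESSING_LOCK,
    timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if lock.acquire(blocking=False):  # skip this run entirely if another worker holds the lock
    try:
        pass  # perform the check / dispatch work
    finally:
        if lock.owned():
            lock.release()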
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_kg_processing_clustering_only(
|
||||
self: Task, *, tenant_id: str
|
||||
) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occasionally does some validation of existing state to clear up error conditions."""
|
||||
|
||||
time_start = time.monotonic()
|
||||
task_logger.warning("check_for_kg_processing_clustering_only - Starting")
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
get_redis_replica_client()
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.KG_PROCESSING_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
locked = True
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
kg_clustering_in_progress = get_kg_processing_in_progress_status(
|
||||
db_session, processing_type=KGProcessingType.CLUSTERING
|
||||
)
|
||||
documents_needing_kg_clustering = check_for_documents_needing_kg_clustering(
|
||||
db_session
|
||||
)
|
||||
|
||||
if kg_clustering_in_progress:
|
||||
task_logger.info(
|
||||
f"KG clustering already in progress for tenant {tenant_id}, skipping"
|
||||
)
|
||||
return None
|
||||
elif not documents_needing_kg_clustering:
|
||||
task_logger.info(
|
||||
f"No documents needing KG clustering for tenant {tenant_id}, skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"Found documents needing KG processing for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.KG_CLUSTERING_ONLY,
|
||||
kwargs={
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.KG_PROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during kg clustering-only check")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_kg_processing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KG_PROCESSING,
|
||||
soft_time_limit=1000,
|
||||
bind=True,
|
||||
)
|
||||
def kg_processing(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Resets the in-progress status flags when finished."""
|
||||
|
||||
time.monotonic()
|
||||
task_logger.warning(f"check_for_kg_processing - Starting for tenant {tenant_id}")
|
||||
|
||||
task_logger.debug("Starting kg processing task!")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_str = search_settings.index_name
|
||||
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session, processing_type=KGProcessingType.EXTRACTION, in_progress=True
|
||||
)
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session, processing_type=KGProcessingType.CLUSTERING, in_progress=True
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")
|
||||
|
||||
try:
|
||||
kg_extraction(
|
||||
tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error during kg extraction: {e}")
|
||||
finally:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session,
|
||||
processing_type=KGProcessingType.EXTRACTION,
|
||||
in_progress=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.debug("Completed kg extraction task. Moving to clustering")
|
||||
|
||||
try:
|
||||
kg_clustering(
|
||||
tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error during kg clustering: {e}")
|
||||
finally:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session,
|
||||
processing_type=KGProcessingType.CLUSTERING,
|
||||
in_progress=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.debug("Completed kg clustering task!")
|
||||
|
||||
task_logger.debug("Completed kg clustering task!")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KG_CLUSTERING_ONLY,
|
||||
soft_time_limit=1000,
|
||||
bind=True,
|
||||
)
|
||||
def kg_clustering_only(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Resets the clustering in-progress status flag when finished."""
|
||||
|
||||
time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_str = search_settings.index_name
|
||||
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session, processing_type=KGProcessingType.CLUSTERING, in_progress=True
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")
|
||||
|
||||
task_logger.debug("Starting kg clustering-only task!")
|
||||
|
||||
try:
|
||||
kg_clustering(
|
||||
tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Error during kg clustering: {e}")
|
||||
finally:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
set_kg_processing_in_progress_status(
|
||||
db_session,
|
||||
processing_type=KGProcessingType.CLUSTERING,
|
||||
in_progress=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.debug("Completed kg clustering task!")
|
||||
|
||||
task_logger.debug("Completed kg clustering task!")
|
||||
|
||||
return None
|
||||
@@ -23,6 +23,7 @@ from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.relationships import delete_document_references_from_kg
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -119,6 +120,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_document_references_from_kg(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Factory stub for running celery worker for knowledge graph processing.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.kg_processing import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -12,6 +12,7 @@ from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
from onyx.chat.models import AnswerStream
|
||||
@@ -24,8 +25,10 @@ from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Persona
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -120,6 +123,7 @@ class Answer:
|
||||
allow_refinement=AGENT_ALLOW_REFINEMENT,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
|
||||
kg_config_settings=get_kg_config_settings(db_session),
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
@@ -134,10 +138,17 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
if self.graph_config.behavior.use_agentic_search:
|
||||
if self.graph_config.behavior.use_agentic_search and (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
and self.graph_config.inputs.persona.name.startswith("KG Dev")
|
||||
):
|
||||
run_langgraph = run_kb_graph
|
||||
elif self.graph_config.behavior.use_agentic_search:
|
||||
run_langgraph = run_agent_search_graph
|
||||
elif (
|
||||
self.graph_config.inputs.persona
|
||||
and USE_DIV_CON_AGENT
|
||||
and self.graph_config.inputs.persona.description.startswith(
|
||||
"DivCon Beta Agent"
|
||||
)
|
||||
|
||||
@@ -79,6 +79,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.milestone import update_user_assistant_milestone
|
||||
@@ -96,6 +97,15 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.kg.clustering.clustering import kg_clustering
|
||||
from onyx.kg.configuration import populate_default_account_employee_definitions
|
||||
from onyx.kg.configuration import populate_default_grounded_entity_types
|
||||
from onyx.kg.extractions.extraction_processing import kg_extraction
|
||||
from onyx.kg.resets.reset_extractions import reset_extraction_kg_index
|
||||
from onyx.kg.resets.reset_index import reset_full_kg_index
|
||||
from onyx.kg.resets.reset_normalizations import reset_normalization_kg_index
|
||||
from onyx.kg.resets.reset_source import reset_source_kg_index
|
||||
from onyx.kg.resets.reset_vespa import reset_vespa_kg_index
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -560,6 +570,57 @@ def stream_chat_message_objects(
|
||||
|
||||
llm: LLM
|
||||
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if kg_config_settings.KG_ENABLED:
|
||||
|
||||
# Temporarily, until we have a draft UI for the KG Operations/Management
|
||||
|
||||
# get Vespa index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_str = search_settings.index_name
|
||||
|
||||
if new_msg_req.message == "kg_e":
|
||||
kg_extraction(tenant_id, index_str)
|
||||
raise Exception("Extractions done")
|
||||
|
||||
elif new_msg_req.message == "kg_c":
|
||||
kg_clustering(tenant_id, index_str)
|
||||
raise Exception("Clustering done")
|
||||
|
||||
elif new_msg_req.message == "kg":
|
||||
reset_vespa_kg_index(tenant_id, index_str)
|
||||
reset_full_kg_index()
|
||||
kg_extraction(tenant_id, index_str)
|
||||
kg_clustering(tenant_id, index_str)
|
||||
raise Exception("Full KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_full":
|
||||
reset_full_kg_index()
|
||||
raise Exception("Full KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_extraction":
|
||||
reset_extraction_kg_index()
|
||||
raise Exception("Extraction KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_normalization":
|
||||
reset_normalization_kg_index()
|
||||
raise Exception("Normalization KG index reset done")
|
||||
|
||||
elif new_msg_req.message.startswith("kg_rs_source:"):
|
||||
source_name = new_msg_req.message.split(":")[1].strip()
|
||||
reset_source_kg_index(source_name, tenant_id, index_str)
|
||||
raise Exception(f"KG index reset for source {source_name} done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_vespa":
|
||||
reset_vespa_kg_index(tenant_id, index_str)
|
||||
raise Exception("Vespa KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_setup":
|
||||
populate_default_grounded_entity_types()
|
||||
populate_default_account_employee_definitions()
|
||||
raise Exception("KG setup done")
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
@@ -184,6 +184,13 @@ POSTGRES_API_SERVER_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
|
||||
)
|
||||
|
||||
POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE") or 10
|
||||
)
|
||||
POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW") or 5
|
||||
)
|
||||
|
||||
# defaults to False
|
||||
# generally should only be used for
|
||||
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
|
||||
@@ -309,6 +316,11 @@ try:
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 1024
|
||||
|
||||
@@ -746,3 +758,9 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
DISABLE_AUTO_AUTH_REFRESH = (
|
||||
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Knowledge Graph Read Only User Configuration
|
||||
DB_READONLY_USER: str = os.environ.get("DB_READONLY_USER", "db_readonly_user")
|
||||
DB_READONLY_PASSWORD: str = urllib.parse.quote_plus(
|
||||
os.environ.get("DB_READONLY_PASSWORD") or "password"
|
||||
)
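For orientation, a hedged sketch of how a read-only connection string could be assembled from these settings; the host, port, and database name are assumptions, and since the password is already URL-quoted above it can be embedded directly:

# Hypothetical read-only Postgres URL built from the settings above.
readonly_url = (
    f"postgresql://{DB_READONLY_USER}:{DB_READONLY_PASSWORD}"
    "@relational_db:5432/postgres"  # assumed host, port, and database
)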
|
||||
|
||||
@@ -102,3 +102,5 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
@@ -68,6 +68,7 @@ POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
@@ -324,6 +325,9 @@ class OnyxCeleryQueues:
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
# KG processing queue
|
||||
KG_PROCESSING = "kg_processing"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -357,6 +361,9 @@ class OnyxRedisLocks:
|
||||
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
|
||||
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
|
||||
|
||||
# KG processing
|
||||
KG_PROCESSING_LOCK = "da_lock:kg_processing"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -373,6 +380,9 @@ class OnyxRedisSignals:
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
# KG processing
|
||||
CHECK_KG_PROCESSING_BEAT_LOCK = "da_lock:check_kg_processing_beat"
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
ACTIVE_FENCES = "active_fences"
|
||||
@@ -456,6 +466,12 @@ class OnyxCeleryTask:
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
# KG processing
|
||||
CHECK_KG_PROCESSING = "check_kg_processing"
|
||||
KG_PROCESSING = "kg_processing"
|
||||
KG_CLUSTERING_ONLY = "kg_clustering_only"
|
||||
CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
|
||||
|
||||
|
||||
# this needs to correspond to the matching entry in supervisord
|
||||
ONYX_CELERY_BEAT_HEARTBEAT_KEY = "onyx:celery:beat:heartbeat"
|
||||
@@ -468,3 +484,8 @@ if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
FIREFLIES = "fireflies"
|
||||
GONG = "gong"
|
||||
|
||||
102
backend/onyx/configs/kg_configs.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import os
|
||||
|
||||
KG_RESEARCH_NUM_RETRIEVED_DOCS: int = int(
|
||||
os.environ.get("KG_RESEARCH_NUM_RETRIEVED_DOCS", "25")
|
||||
)
|
||||
|
||||
|
||||
KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES: int = int(
|
||||
os.environ.get("KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES", "10")
|
||||
)
|
||||
|
||||
|
||||
KG_ENTITY_EXTRACTION_TIMEOUT: int = int(
|
||||
os.environ.get("KG_ENTITY_EXTRACTION_TIMEOUT", "15")
|
||||
)
|
||||
|
||||
KG_RELATIONSHIP_EXTRACTION_TIMEOUT: int = int(
|
||||
os.environ.get("KG_RELATIONSHIP_EXTRACTION_TIMEOUT", "15")
|
||||
)
|
||||
|
||||
KG_STRATEGY_GENERATION_TIMEOUT: int = int(
|
||||
os.environ.get("KG_STRATEGY_GENERATION_TIMEOUT", "20")
|
||||
)
|
||||
|
||||
KG_SQL_GENERATION_TIMEOUT: int = int(os.environ.get("KG_SQL_GENERATION_TIMEOUT", "25"))
|
||||
|
||||
KG_FILTER_CONSTRUCTION_TIMEOUT: int = int(
|
||||
os.environ.get("KG_FILTER_CONSTRUCTION_TIMEOUT", "15")
|
||||
)
|
||||
|
||||
|
||||
KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT: int = int(
|
||||
os.environ.get("KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT", "100")
|
||||
)
|
||||
|
||||
KG_FILTERED_SEARCH_TIMEOUT: int = int(
|
||||
os.environ.get("KG_FILTERED_SEARCH_TIMEOUT", "30")
|
||||
)
|
||||
|
||||
|
||||
KG_OBJECT_SOURCE_RESEARCH_TIMEOUT: int = int(
|
||||
os.environ.get("KG_OBJECT_SOURCE_RESEARCH_TIMEOUT", "30")
|
||||
)
|
||||
|
||||
KG_ANSWER_GENERATION_TIMEOUT: int = int(
|
||||
os.environ.get("KG_ANSWER_GENERATION_TIMEOUT", "30")
|
||||
)
|
||||
|
||||
KG_MAX_DEEP_SEARCH_RESULTS: int = int(
|
||||
os.environ.get("KG_MAX_DEEP_SEARCH_RESULTS", "30")
|
||||
)
|
||||
|
||||
|
||||
KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH: int = int(
|
||||
os.environ.get("KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH", "2")
|
||||
)
|
||||
|
||||
|
||||
_KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT: float = max(
|
||||
1e-3,
|
||||
min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT", "0.25"))),
|
||||
)
|
||||
_KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT: float = max(
|
||||
1e-3,
|
||||
min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT", "0.25"))),
|
||||
)
|
||||
_KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT: float = max(
|
||||
1e-3,
|
||||
min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT", "0.5"))),
|
||||
)
|
||||
_KG_NORMALIZATION_RERANK_NGRAM_SUMS: float = (
|
||||
_KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT
|
||||
+ _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT
|
||||
+ _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT
|
||||
)
|
||||
|
||||
KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS: tuple[float, float, float] = (
|
||||
_KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
|
||||
_KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
|
||||
_KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
|
||||
)
|
||||
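The three n-gram weights above are each clamped into [1e-3, 1] and then divided by their sum, so the resulting tuple always sums to 1. A minimal standalone sketch of that normalization, using the same environment variables and defaults as the module (this is an illustration, not part of the file):

import os

raw = {
    "unigram": float(os.environ.get("KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT", "0.25")),
    "bigram": float(os.environ.get("KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT", "0.25")),
    "trigram": float(os.environ.get("KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT", "0.5")),
}
clamped = {k: max(1e-3, min(1, v)) for k, v in raw.items()}
total = sum(clamped.values())
weights = tuple(clamped[k] / total for k in ("unigram", "bigram", "trigram"))
assert abs(sum(weights) - 1.0) < 1e-9  # normalized weights always sum to 1
print(weights)  # (0.25, 0.25, 0.5) with the defaults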
|
||||
|
||||
KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT: float = max(
|
||||
0,
|
||||
min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT", "0.25"))),
|
||||
)
|
||||
|
||||
|
||||
KG_NORMALIZATION_RERANK_THRESHOLD: float = float(
|
||||
os.environ.get("KG_NORMALIZATION_RERANK_THRESHOLD", "0.3")
|
||||
)
|
||||
|
||||
|
||||
KG_CLUSTERING_RETRIEVE_THRESHOLD: float = float(
|
||||
os.environ.get("KG_CLUSTERING_RETRIEVE_THRESHOLD", "0.6")
|
||||
)
|
||||
|
||||
|
||||
KG_CLUSTERING_THRESHOLD: float = float(
|
||||
os.environ.get("KG_CLUSTERING_THRESHOLD", "0.96")
|
||||
)
|
||||
@@ -193,12 +193,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
team {
|
||||
name
|
||||
}
|
||||
assignee {
|
||||
email
|
||||
}
|
||||
previousIdentifiers
|
||||
subIssueSortOrder
|
||||
priorityLabel
|
||||
identifier
|
||||
url
|
||||
branchName
|
||||
state {
|
||||
id
|
||||
name
|
||||
}
|
||||
customerTicketCount
|
||||
description
|
||||
comments {
|
||||
@@ -267,7 +274,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
title=node["title"],
|
||||
doc_updated_at=time_str_to_utc(node["updatedAt"]),
|
||||
metadata={
|
||||
"team": node["team"]["name"],
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
"team": (node.get("team") or {}).get("name"),
|
||||
"assignee": (node.get("assignee") or {}).get("email"),
|
||||
"state": (node.get("state") or {}).get("name"),
|
||||
"priority": node.get("priority"),
|
||||
"estimate": node.get("estimate"),
|
||||
"started_at": node.get("startedAt"),
|
||||
"completed_at": node.get("completedAt"),
|
||||
"created_at": node.get("createdAt"),
|
||||
"due_date": node.get("dueDate"),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
)
|
||||
)
|
||||
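The rewritten metadata block builds the dict in one pass: collect the candidate fields (tolerating missing team/assignee/state objects via `or {}`), drop anything that is None, and stringify the rest. A small self-contained illustration of that pattern with a made-up node payload (not Linear's real response shape):

node = {"team": {"name": "Core"}, "assignee": None, "priority": 2, "dueDate": None}
metadata = {
    k: str(v)
    for k, v in {
        "team": (node.get("team") or {}).get("name"),
        "assignee": (node.get("assignee") or {}).get("email"),
        "priority": node.get("priority"),
        "due_date": node.get("dueDate"),
    }.items()
    if v is not None  # None values never reach the metadata dict
}
print(metadata)  # {'team': 'Core', 'priority': '2'}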
|
||||
@@ -134,29 +134,26 @@ def process_jira_issue(
|
||||
|
||||
metadata_dict: dict[str, str | list[str]] = {}
|
||||
people = set()
|
||||
try:
|
||||
creator = best_effort_get_field_from_issue(issue, _FIELD_REPORTER)
|
||||
if basic_expert_info := best_effort_basic_expert_info(creator):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
metadata_dict[_FIELD_REPORTER_EMAIL] = email
|
||||
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
creator = best_effort_get_field_from_issue(issue, _FIELD_REPORTER)
|
||||
if creator is not None and (
|
||||
basic_expert_info := best_effort_basic_expert_info(creator)
|
||||
):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
metadata_dict[_FIELD_REPORTER_EMAIL] = email
|
||||
|
||||
try:
|
||||
assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE)
|
||||
if basic_expert_info := best_effort_basic_expert_info(assignee):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
|
||||
except Exception:
|
||||
# Author should exist but if not, doesn't matter
|
||||
pass
|
||||
assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE)
|
||||
if assignee is not None and (
|
||||
basic_expert_info := best_effort_basic_expert_info(assignee)
|
||||
):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
|
||||
|
||||
metadata_dict[_FIELD_KEY] = issue.key
|
||||
if priority := best_effort_get_field_from_issue(issue, _FIELD_PRIORITY):
|
||||
metadata_dict[_FIELD_PRIORITY] = priority.name
|
||||
if status := best_effort_get_field_from_issue(issue, _FIELD_STATUS):
|
||||
@@ -178,20 +175,17 @@ def process_jira_issue(
|
||||
):
|
||||
metadata_dict[_FIELD_RESOLUTION_DATE_KEY] = resolutiondate
|
||||
|
||||
try:
|
||||
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
|
||||
if parent:
|
||||
metadata_dict[_FIELD_PARENT] = parent.key
|
||||
except Exception:
|
||||
# Parent should exist but if not, doesn't matter
|
||||
pass
|
||||
try:
|
||||
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
|
||||
if project:
|
||||
metadata_dict[_FIELD_PROJECT_NAME] = project.name
|
||||
metadata_dict[_FIELD_PROJECT] = project.key
|
||||
except Exception:
|
||||
# Project should exist.
|
||||
|
||||
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
|
||||
if parent is not None:
|
||||
metadata_dict[_FIELD_PARENT] = parent.key
|
||||
|
||||
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
|
||||
if project is not None:
|
||||
metadata_dict[_FIELD_PROJECT_NAME] = project.name
|
||||
metadata_dict[_FIELD_PROJECT] = project.key
|
||||
else:
|
||||
|
||||
logger.error(f"Project should exist but does not for {issue.key}")
|
||||
|
||||
return Document(
|
||||
|
||||
@@ -23,15 +23,19 @@ JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
|
||||
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
|
||||
display_name = None
|
||||
email = None
|
||||
if hasattr(obj, "displayName"):
|
||||
display_name = obj.displayName
|
||||
else:
|
||||
display_name = obj.get("displayName")
|
||||
|
||||
if hasattr(obj, "emailAddress"):
|
||||
email = obj.emailAddress
|
||||
else:
|
||||
email = obj.get("emailAddress")
|
||||
try:
|
||||
if hasattr(obj, "displayName"):
|
||||
display_name = obj.displayName
|
||||
else:
|
||||
display_name = obj.get("displayName")
|
||||
|
||||
if hasattr(obj, "emailAddress"):
|
||||
email = obj.emailAddress
|
||||
else:
|
||||
email = obj.get("emailAddress")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if not email and not display_name:
|
||||
return None
|
||||
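Wrapping the attribute/dict access in a try/except lets the helper accept either a Jira resource object (attribute access) or a plain dict, and bail out cleanly on anything else. A toy illustration of that behavior (the class and values here are invented, not jira types):

class _Author:  # stand-in for a jira resource with attribute access
    displayName = "Ada Lovelace"
    emailAddress = "ada@example.com"

def _extract(obj):
    try:
        name = obj.displayName if hasattr(obj, "displayName") else obj.get("displayName")
        email = obj.emailAddress if hasattr(obj, "emailAddress") else obj.get("emailAddress")
    except Exception:
        return None
    return name, email

print(_extract(_Author()))                        # ('Ada Lovelace', 'ada@example.com')
print(_extract({"displayName": "Grace Hopper"}))  # ('Grace Hopper', None)
print(_extract(42))                               # None: no .get(), the exception is swallowed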
|
||||
@@ -35,6 +35,28 @@ logger = setup_logger()
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
"Opportunity": {
|
||||
"Account": "account",
|
||||
"FiscalQuarter": "fiscal_quarter",
|
||||
"FiscalYear": "fiscal_year",
|
||||
"IsClosed": "is_closed",
|
||||
"Name": "name",
|
||||
"StageName": "stage_name",
|
||||
"Type": "type",
|
||||
"Amount": "amount",
|
||||
"CloseDate": "close_date",
|
||||
"Probability": "probability",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
},
|
||||
"Contact": {
|
||||
"Account": "account",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""Approach outline
|
||||
@@ -170,6 +192,8 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Always want to make sure user is grabbed for permissioning purposes
|
||||
all_types.add("User")
|
||||
# Always want to make sure account is grabbed for reference purposes
|
||||
all_types.add("Account")
|
||||
|
||||
logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -282,7 +306,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
f"remaining={len(updated_ids) - docs_processed}"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
parent_object = sf_db.get_record(
|
||||
parent_id, parent_type, isChild=False
|
||||
)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
@@ -294,6 +320,18 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
|
||||
@@ -172,7 +172,7 @@ def convert_sf_object_to_doc(
|
||||
|
||||
sections = [_extract_section(sf_object, base_url)]
|
||||
for id in sf_db.get_child_ids(sf_object.id):
|
||||
if not (child_object := sf_db.get_record(id)):
|
||||
if not (child_object := sf_db.get_record(id, isChild=True)):
|
||||
continue
|
||||
sections.append(_extract_section(child_object, base_url))
|
||||
|
||||
|
||||
@@ -456,7 +456,7 @@ class OnyxSalesforceSQLite:
|
||||
return result[0]
|
||||
|
||||
def get_record(
|
||||
self, object_id: str, object_type: str | None = None
|
||||
self, object_id: str, object_type: str | None = None, isChild: bool = False
|
||||
) -> SalesforceObject | None:
|
||||
"""Retrieve the record and return it as a SalesforceObject."""
|
||||
if self._conn is None:
|
||||
@@ -469,15 +469,44 @@ class OnyxSalesforceSQLite:
|
||||
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
# Get the object data and account data
|
||||
if object_type == "Account" or isChild:
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT pso.data, r.parent_id as parent_id, sso.object_type FROM salesforce_objects pso \
|
||||
LEFT JOIN relationships r on r.child_id = pso.id \
|
||||
LEFT JOIN salesforce_objects sso on r.parent_id = sso.id \
|
||||
WHERE pso.id = ? ",
|
||||
(object_id,),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
|
||||
data = json.loads(result[0])
|
||||
data = json.loads(result[0][0])
|
||||
|
||||
if object_type != "Account":
|
||||
|
||||
# convert any account ids of the relationships back into data fields, with name
|
||||
for row in result:
|
||||
|
||||
# the following skips Account objects.
|
||||
if len(row) < 3:
|
||||
continue
|
||||
|
||||
if row[1] and row[2] and row[2] == "Account":
|
||||
data["AccountId"] = row[1]
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?",
|
||||
(row[1],),
|
||||
)
|
||||
account_data = json.loads(cursor.fetchone()[0])
|
||||
data["Account"] = account_data.get("Name", "")
|
||||
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
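The reworked `get_record` joins the local `relationships` table so that a non-Account record comes back with its parent Account id (and name) folded into its data. A self-contained sqlite3 sketch of that join on a toy dataset (the table and column names mirror the query above, but the rows are invented):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE salesforce_objects (id TEXT PRIMARY KEY, object_type TEXT, data TEXT);
    CREATE TABLE relationships (child_id TEXT, parent_id TEXT);
    INSERT INTO salesforce_objects VALUES ('opp1', 'Opportunity', '{"Name": "Big Deal"}');
    INSERT INTO salesforce_objects VALUES ('acct1', 'Account', '{"Name": "Acme"}');
    INSERT INTO relationships VALUES ('opp1', 'acct1');
    """
)
rows = conn.execute(
    "SELECT pso.data, r.parent_id, sso.object_type "
    "FROM salesforce_objects pso "
    "LEFT JOIN relationships r ON r.child_id = pso.id "
    "LEFT JOIN salesforce_objects sso ON r.parent_id = sso.id "
    "WHERE pso.id = ?",
    ("opp1",),
).fetchall()
data = json.loads(rows[0][0])
for _, parent_id, parent_type in rows:
    if parent_id and parent_type == "Account":
        data["AccountId"] = parent_id
        account_row = conn.execute(
            "SELECT data FROM salesforce_objects WHERE id = ?", (parent_id,)
        ).fetchone()
        data["Account"] = json.loads(account_row[0]).get("Name", "")
print(data)  # {'Name': 'Big Deal', 'AccountId': 'acct1', 'Account': 'Acme'}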
|
||||
def find_ids_by_type(self, object_type: str) -> list[str]:
|
||||
|
||||
@@ -111,6 +111,11 @@ class BaseFilters(BaseModel):
|
||||
document_set: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
kg_sources: list[str] | None = None
|
||||
kg_chunk_id_zero_only: bool | None = False
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
|
||||
@@ -183,6 +183,11 @@ def retrieval_preprocessing(
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
kg_entities=preset_filters.kg_entities,
|
||||
kg_relationships=preset_filters.kg_relationships,
|
||||
kg_terms=preset_filters.kg_terms,
|
||||
kg_sources=preset_filters.kg_sources,
|
||||
kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.db.enums import IndexingMode
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.kg.models import KGConnectorData
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.models import StatusResponse
|
||||
@@ -334,3 +335,29 @@ def mark_ccpair_with_indexing_trigger(
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_kg_enabled_connectors(db_session: Session) -> list[KGConnectorData]:
|
||||
"""
|
||||
Retrieves the connectors that have KG processing enabled for the current tenant.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[KGConnectorData]: Connector id, source, and KG coverage days for each connector with KG processing enabled
|
||||
"""
|
||||
try:
|
||||
stmt = select(Connector.id, Connector.source, Connector.kg_coverage_days).where(
|
||||
Connector.kg_processing_enabled
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
connector_results = [
|
||||
KGConnectorData(id=row[0], source=row[1].lower(), kg_coverage_days=row[2])
|
||||
for row in result.fetchall()
|
||||
]
|
||||
|
||||
return connector_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -21,23 +22,36 @@ from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.entities import delete_from_kg_entities__no_commit
|
||||
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import User
|
||||
from onyx.db.relationships import delete_from_kg_relationships__no_commit
|
||||
from onyx.db.relationships import (
|
||||
delete_from_kg_relationships_extraction_staging__no_commit,
|
||||
)
|
||||
from onyx.db.tag import delete_document_tags_for_documents__no_commit
|
||||
from onyx.db.utils import model_to_dict
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.kg.utils.formatting_utils import split_entity_id
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -392,6 +406,7 @@ def upsert_documents(
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -605,7 +620,29 @@ def delete_documents_complete__no_commit(
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
|
||||
# Start by deleting the chunk stats for the documents
|
||||
# Start with the kg references
|
||||
|
||||
delete_from_kg_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_entities__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_relationships_extraction_staging__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_entities_extraction_staging__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Continue with deleting the chunk stats for the documents
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
@@ -858,3 +895,374 @@ def fetch_chunk_count_for_document(
|
||||
) -> int | None:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_unprocessed_kg_document_batch_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
kg_coverage_start: datetime,
|
||||
kg_max_coverage_days: int,
|
||||
batch_size: int = 100,
|
||||
) -> list[DbDocument]:
|
||||
"""
|
||||
Retrieves a batch of documents that have not been processed for knowledge graph extraction.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
connector_id (int): The ID of the connector to get documents for
kg_coverage_start (datetime): Earliest doc_updated_at to consider for KG processing
kg_max_coverage_days (int): Maximum document age in days to consider
batch_size (int): The maximum number of documents to retrieve
|
||||
Returns:
|
||||
list[DbDocument]: List of documents that need KG processing
|
||||
"""
|
||||
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DbDocument.doc_updated_at >= kg_coverage_start,
|
||||
DbDocument.doc_updated_at
|
||||
>= datetime.now() - timedelta(days=kg_max_coverage_days),
|
||||
or_(
|
||||
DbDocument.kg_stage.is_(None),
|
||||
DbDocument.kg_stage == KGStage.NOT_STARTED,
|
||||
DbDocument.doc_updated_at > DbDocument.kg_processing_time,
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.order_by(DbDocument.doc_updated_at.desc())
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
documents = db_session.scalars(stmt).all()
|
||||
db_session.flush()
|
||||
|
||||
return list(documents)
|
||||
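A hedged usage sketch of the batch fetch above; the connector id and coverage window are placeholders, and it assumes this helper lives in `onyx.db.document` alongside the other document helpers in this hunk:

from datetime import datetime

from onyx.db.document import get_unprocessed_kg_document_batch_for_connector
from onyx.db.engine import get_session_context_manager

with get_session_context_manager() as db_session:
    batch = get_unprocessed_kg_document_batch_for_connector(
        db_session=db_session,
        connector_id=1,  # hypothetical connector id
        kg_coverage_start=datetime(2024, 1, 1),  # illustrative coverage window
        kg_max_coverage_days=90,
        batch_size=50,
    )
    print(f"{len(batch)} documents still need KG extraction")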
|
||||
|
||||
def get_kg_extracted_document_ids(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_stage is EXTRACTED.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been KG processed
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.EXTRACTED)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def update_document_kg_info(
|
||||
db_session: Session, document_id: str, kg_stage: KGStage
|
||||
) -> None:
|
||||
"""Updates the knowledge graph related information for a document.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to update
|
||||
kg_stage (KGStage): The stage of the knowledge graph processing for the document
|
||||
Raises:
|
||||
ValueError: If the document with the given ID is not found
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(
|
||||
kg_stage=kg_stage,
|
||||
kg_processing_time=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def update_document_kg_stage(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
kg_stage: KGStage,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(DbDocument).where(DbDocument.id == document_id).values(kg_stage=kg_stage)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_all_kg_extracted_documents_info(
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
"""Retrieves the knowledge graph data for all documents that have been processed.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[str]: The document IDs
Only returns documents where kg_stage is EXTRACTED
|
||||
"""
|
||||
stmt = (
|
||||
select(DbDocument.id)
|
||||
.where(DbDocument.kg_stage == KGStage.EXTRACTED)
|
||||
.order_by(DbDocument.id)
|
||||
)
|
||||
|
||||
results = db_session.scalars(stmt).all()
|
||||
return [str(doc_id) for doc_id in results]
|
||||
|
||||
|
||||
def get_base_llm_doc_information(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> list[str]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
results = db_session.execute(stmt).all()
|
||||
|
||||
documents = []
|
||||
|
||||
for doc_nr, doc in enumerate(results):
|
||||
bare_doc = doc[0]
|
||||
documents.append(
|
||||
f"""* [{bare_doc.semantic_id}]({bare_doc.link}) ({bare_doc.doc_updated_at})"""
|
||||
)
|
||||
|
||||
return documents[:KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES]
|
||||
|
||||
|
||||
def get_document_updated_at(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> datetime | None:
|
||||
"""Retrieves the doc_updated_at timestamp for a given document ID.
|
||||
Args:
|
||||
document_id (str): The ID of the document to query
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
|
||||
"""
|
||||
parts = split_entity_id(document_id)
|
||||
if len(parts) == 2:
|
||||
document_id = parts[1]
|
||||
elif len(parts) > 2:
|
||||
raise ValueError(f"Invalid document ID: {document_id}")
|
||||
|
||||
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def reset_all_document_kg_stages(db_session: Session) -> int:
|
||||
"""Reset the KG stage of all documents that are not in NOT_STARTED state to NOT_STARTED.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
int: Number of documents that were reset
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.kg_stage != KGStage.NOT_STARTED)
|
||||
.values(kg_stage=KGStage.NOT_STARTED)
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def update_document_kg_stages(
|
||||
db_session: Session, source_stage: KGStage, target_stage: KGStage
|
||||
) -> int:
|
||||
"""Reset the KG stage only of documents back to NOT_STARTED.
|
||||
Part of reset flow for documents that have been extracted but not clustered.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
int: Number of documents that were reset
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.kg_stage == source_stage)
|
||||
.values(kg_stage=target_stage)
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def get_skipped_kg_documents(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_stage is SKIPPED.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been skipped in KG processing
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.SKIPPED)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_kg_doc_info_for_entity_name(
|
||||
db_session: Session, document_id: str, entity_type: str
|
||||
) -> KGEntityDocInfo:
|
||||
"""
|
||||
Get the semantic ID and the link for an entity name.
|
||||
"""
|
||||
|
||||
result = (
|
||||
db_session.query(Document.semantic_id, Document.link)
|
||||
.filter(Document.id == document_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return KGEntityDocInfo(
|
||||
doc_id=None,
|
||||
doc_semantic_id=None,
|
||||
doc_link=None,
|
||||
semantic_entity_name=f"{entity_type}:{document_id}",
|
||||
semantic_linked_entity_name=f"{entity_type}:{document_id}",
|
||||
)
|
||||
|
||||
return KGEntityDocInfo(
|
||||
doc_id=document_id,
|
||||
doc_semantic_id=result[0],
|
||||
doc_link=result[1],
|
||||
semantic_entity_name=f"{entity_type.upper()}:{result[0]}",
|
||||
semantic_linked_entity_name=f"[{entity_type.upper()}:{result[0]}]({result[1]})",
|
||||
)
|
||||
|
||||
|
||||
def check_for_documents_needing_kg_processing(
|
||||
db_session: Session, kg_coverage_start: datetime, kg_max_coverage_days: int
|
||||
) -> bool:
|
||||
"""Check if there are any documents that need KG processing.
|
||||
|
||||
A document needs KG processing if:
|
||||
1. It is associated with a connector that has kg_processing_enabled = true
|
||||
2. AND either:
|
||||
- Its kg_stage is NOT_STARTED or NULL
|
||||
- OR its last_updated timestamp is greater than its kg_processing_time
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
bool: True if there are any documents needing KG processing, False otherwise
|
||||
"""
|
||||
|
||||
stmt = (
|
||||
select(1)
|
||||
.select_from(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
DocumentByConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
Connector.kg_processing_enabled.is_(True),
|
||||
DbDocument.doc_updated_at >= kg_coverage_start,
|
||||
DbDocument.doc_updated_at
|
||||
>= datetime.now() - timedelta(days=kg_max_coverage_days),
|
||||
or_(
|
||||
or_(
|
||||
DbDocument.kg_stage.is_(None),
|
||||
DbDocument.kg_stage == KGStage.NOT_STARTED,
|
||||
),
|
||||
DbDocument.doc_updated_at > DbDocument.kg_processing_time,
|
||||
),
|
||||
)
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
return db_session.execute(select(stmt)).scalar() or False
|
||||
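The query above wraps the whole filtered join in `.exists()` and selects it once, so Postgres can stop at the first matching row instead of scanning them all. A minimal sketch of that pattern against a hypothetical `Item` model with a boolean `processed` column (not an Onyx table):

from sqlalchemy import select

def any_unprocessed_items(db_session, item_model) -> bool:
    # EXISTS subquery: evaluates to True as soon as one row satisfies the predicate
    exists_clause = (
        select(1)
        .select_from(item_model)
        .where(item_model.processed.is_(False))
        .exists()
    )
    return db_session.execute(select(exists_clause)).scalar() or False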
|
||||
|
||||
def check_for_documents_needing_kg_clustering(db_session: Session) -> bool:
|
||||
"""Check if there are any documents that need KG clustering.
|
||||
|
||||
A document needs KG clustering if:
|
||||
1. It is associated with a connector that has kg_processing_enabled = true
|
||||
2. AND either:
|
||||
- Its kg_stage is EXTRACTED
|
||||
- OR its last_updated timestamp is greater than its kg_processing_time
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
bool: True if there are any documents needing KG clustering, False otherwise
|
||||
"""
|
||||
stmt = (
|
||||
select(1)
|
||||
.select_from(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
Connector.kg_processing_enabled.is_(True),
|
||||
ConnectorCredentialPair.status
|
||||
!= ConnectorCredentialPairStatus.DELETING,
|
||||
or_(
|
||||
DbDocument.kg_stage == KGStage.EXTRACTED,
|
||||
DbDocument.last_modified > DbDocument.kg_processing_time,
|
||||
),
|
||||
)
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
return db_session.execute(select(stmt)).scalar() or False
|
||||
|
||||
|
||||
def get_document_kg_entities_and_relationships(
|
||||
db_session: Session, document_id: str
|
||||
) -> tuple[list[KGEntity], list[KGRelationship]]:
|
||||
"""
|
||||
Get the KG entities and relationships that reference the document.
|
||||
"""
|
||||
entities = (
|
||||
db_session.query(KGEntity).filter(KGEntity.document_id == document_id).all()
|
||||
)
|
||||
if not entities:
|
||||
return [], []
|
||||
entity_id_names = [entity.id_name for entity in entities]
|
||||
|
||||
relationships = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationship.source_node.in_(entity_id_names),
|
||||
KGRelationship.target_node.in_(entity_id_names),
|
||||
KGRelationship.source_document == document_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return entities, relationships
|
||||
|
||||
@@ -27,6 +27,8 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -187,7 +189,9 @@ def is_valid_schema_name(name: str) -> bool:
|
||||
|
||||
class SqlEngine:
|
||||
_engine: Engine | None = None
|
||||
_readonly_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_readonly_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
@classmethod
|
||||
@@ -252,12 +256,80 @@ class SqlEngine:
|
||||
|
||||
cls._engine = engine
|
||||
|
||||
@classmethod
|
||||
def init_readonly_engine(
|
||||
cls,
|
||||
pool_size: int,
|
||||
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
|
||||
max_overflow: int,
|
||||
**extra_engine_kwargs: Any,
|
||||
) -> None:
|
||||
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
|
||||
important args, and if incorrectly specified, we have run into hitting the pool
|
||||
limit / using too many connections and overwhelming the database."""
|
||||
with cls._readonly_lock:
|
||||
if cls._readonly_engine:
|
||||
return
|
||||
|
||||
if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
|
||||
raise ValueError(
|
||||
"Custom database user credentials not configured in environment variables"
|
||||
)
|
||||
|
||||
# Build connection string with custom user
|
||||
connection_string = build_connection_string(
|
||||
user=DB_READONLY_USER,
|
||||
password=DB_READONLY_PASSWORD,
|
||||
use_iam_auth=False, # Custom users typically don't use IAM auth
|
||||
db_api=SYNC_DB_API, # Explicitly use sync DB API
|
||||
)
|
||||
|
||||
# Start with base kwargs that are valid for all pool types
|
||||
final_engine_kwargs: dict[str, Any] = {}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
# if null pool is specified, then we need to make sure that
|
||||
# we remove any passed in kwargs related to pool size that would
|
||||
# cause the initialization to fail
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
final_engine_kwargs["poolclass"] = pool.NullPool
|
||||
if "pool_size" in final_engine_kwargs:
|
||||
del final_engine_kwargs["pool_size"]
|
||||
if "max_overflow" in final_engine_kwargs:
|
||||
del final_engine_kwargs["max_overflow"]
|
||||
else:
|
||||
final_engine_kwargs["pool_size"] = pool_size
|
||||
final_engine_kwargs["max_overflow"] = max_overflow
|
||||
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
|
||||
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
|
||||
|
||||
# any passed in kwargs override the defaults
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
cls._readonly_engine = engine
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
if not cls._engine:
|
||||
raise RuntimeError("Engine not initialized. Must call init_engine first.")
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
def get_readonly_engine(cls) -> Engine:
|
||||
if not cls._readonly_engine:
|
||||
raise RuntimeError(
|
||||
"Readonly engine not initialized. Must call init_readonly_engine first."
|
||||
)
|
||||
return cls._readonly_engine
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
cls._app_name = app_name
|
||||
@@ -307,6 +379,10 @@ def get_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
def get_readonly_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_readonly_engine()
|
||||
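Illustrative startup wiring for the read-only engine added above; the pool numbers are placeholders, and DB_READONLY_USER / DB_READONLY_PASSWORD must already be set in the environment for `init_readonly_engine` to succeed:

from sqlalchemy import text

from onyx.db.engine import SqlEngine, get_readonly_sqlalchemy_engine

SqlEngine.init_readonly_engine(pool_size=5, max_overflow=10)  # sizes are illustrative
readonly_engine = get_readonly_sqlalchemy_engine()
with readonly_engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar_one())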
|
||||
|
||||
async def get_async_connection() -> Any:
|
||||
"""
|
||||
Custom connection function for async engine when using IAM auth.
|
||||
@@ -444,6 +520,9 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset search path: {e}")
|
||||
connection.rollback()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -561,3 +640,49 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
|
||||
region = os.getenv("AWS_REGION_NAME", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_readonly_user_session_with_current_tenant() -> (
|
||||
Generator[Session, None, None]
|
||||
):
|
||||
"""
|
||||
Generate a database session using a custom database user for the current tenant.
|
||||
The custom user credentials are obtained from environment variables.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
readonly_engine = get_readonly_sqlalchemy_engine()
|
||||
|
||||
event.listen(readonly_engine, "checkout", _set_search_path_on_checkout__listener)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
with readonly_engine.connect() as connection:
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reset search path: {e}")
|
||||
connection.rollback()
|
||||
finally:
|
||||
cursor.close()
|
||||
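A hedged usage sketch of the context manager above; it presupposes the read-only engine has been initialized and a tenant context is set, and the query is just a placeholder:

from sqlalchemy import text

from onyx.db.engine import get_db_readonly_user_session_with_current_tenant

with get_db_readonly_user_session_with_current_tenant() as session:
    # runs as the permission-less read-only user, scoped to the current tenant schema
    print(session.execute(text("SELECT current_user")).scalar_one())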
|
||||
309
backend/onyx/db/entities.py
Normal file
@@ -0,0 +1,309 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import onyx.db.document as dbdocument
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.kg.models import KGGroundingType
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.kg.utils.formatting_utils import make_entity_id
|
||||
|
||||
|
||||
def upsert_staging_entity(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
entity_type: str,
|
||||
document_id: str | None = None,
|
||||
occurrences: int = 1,
|
||||
attributes: dict[str, str] | None = None,
|
||||
event_time: datetime | None = None,
|
||||
) -> KGEntityExtractionStaging:
|
||||
"""Add or update a new staging entity to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
name: Name of the entity
|
||||
entity_type: Type of the entity (must match an existing KGEntityType)
|
||||
document_id: ID of the document the entity belongs to
|
||||
occurrences: Number of times this entity has been found
|
||||
attributes: Attributes of the entity
|
||||
event_time: Time the entity was added to the database
|
||||
|
||||
Returns:
|
||||
KGEntityExtractionStaging: The created entity
|
||||
"""
|
||||
entity_type = entity_type.upper()
|
||||
name = name.title()
|
||||
id_name = make_entity_id(entity_type, name)
|
||||
attributes = attributes or {}
|
||||
|
||||
entity_type_split = entity_type.split("-")
|
||||
entity_class, entity_subtype = (
|
||||
entity_type_split if len(entity_type_split) == 2 else (entity_type, None)
|
||||
)
|
||||
|
||||
entity_key = attributes.get("key")
|
||||
entity_parent = attributes.get("parent")
|
||||
|
||||
keep_attributes = {
|
||||
attr_key: attr_val
|
||||
for attr_key, attr_val in attributes.items()
|
||||
if not (
|
||||
(attr_key in ("key", "parent") and entity_class)
|
||||
or attr_key in ("object_type", "issuetype")
|
||||
or "_email" in attr_key
|
||||
)
|
||||
}
|
||||
|
||||
# Create new entity
|
||||
stmt = (
|
||||
pg_insert(KGEntityExtractionStaging)
|
||||
.values(
|
||||
id_name=id_name,
|
||||
name=name,
|
||||
entity_type_id_name=entity_type,
|
||||
entity_class=entity_class,
|
||||
entity_subtype=entity_subtype,
|
||||
entity_key=entity_key,
|
||||
parent_key=entity_parent,
|
||||
document_id=document_id,
|
||||
occurrences=occurrences,
|
||||
attributes=keep_attributes,
|
||||
event_time=event_time,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
occurrences=KGEntityExtractionStaging.occurrences + occurrences,
|
||||
),
|
||||
)
|
||||
.returning(KGEntityExtractionStaging)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
if result is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to create or increment staging entity with id_name: {id_name}"
|
||||
)
|
||||
|
||||
# Update the document's kg_stage if document_id is provided
|
||||
if document_id is not None:
|
||||
db_session.query(Document).filter(Document.id == document_id).update(
|
||||
{
|
||||
"kg_stage": KGStage.EXTRACTED,
|
||||
"kg_processing_time": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
return result
|
||||
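The staging insert above relies on PostgreSQL's INSERT ... ON CONFLICT DO UPDATE, keyed on `id_name`, so that repeat extractions of the same entity become an occurrence-count increment instead of a duplicate row. A generic sketch of the same pattern against a hypothetical `Counter` model with a unique `key` column (not part of Onyx):

from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

def upsert_counter(db_session: Session, counter_model, key: str, amount: int = 1):
    stmt = (
        pg_insert(counter_model)
        .values(key=key, count=amount)
        .on_conflict_do_update(
            index_elements=["key"],  # unique column that triggers the conflict path
            set_=dict(count=counter_model.count + amount),  # increment instead of insert
        )
        .returning(counter_model)
    )
    return db_session.execute(stmt).scalar()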
|
||||
|
||||
def transfer_entity(
|
||||
db_session: Session,
|
||||
entity: KGEntityExtractionStaging,
|
||||
) -> KGEntity:
|
||||
"""Transfer an entity from the extraction staging table to the normalized table.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity: Entity to transfer
|
||||
|
||||
Returns:
|
||||
KGEntity: The transferred entity
|
||||
"""
|
||||
# Create the transferred entity
|
||||
stmt = (
|
||||
pg_insert(KGEntity)
|
||||
.values(
|
||||
id_name=make_entity_id(entity.entity_type_id_name, uuid.uuid4().hex[:20]),
|
||||
name=entity.name.casefold(),
|
||||
entity_class=entity.entity_class,
|
||||
entity_subtype=entity.entity_subtype,
|
||||
entity_key=entity.entity_key,
|
||||
alternative_names=entity.alternative_names or [],
|
||||
entity_type_id_name=entity.entity_type_id_name,
|
||||
document_id=entity.document_id,
|
||||
occurrences=entity.occurrences,
|
||||
attributes=entity.attributes,
|
||||
event_time=entity.event_time,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["name", "entity_type_id_name", "document_id"],
|
||||
set_=dict(
|
||||
occurrences=KGEntity.occurrences + entity.occurrences,
|
||||
),
|
||||
)
|
||||
.returning(KGEntity)
|
||||
)
|
||||
new_entity = db_session.execute(stmt).scalar()
|
||||
if new_entity is None:
|
||||
raise RuntimeError(f"Failed to transfer entity with id_name: {entity.id_name}")
|
||||
|
||||
# Update the document's kg_stage if document_id is provided
|
||||
if entity.document_id is not None:
|
||||
dbdocument.update_document_kg_info(
|
||||
db_session,
|
||||
document_id=entity.document_id,
|
||||
kg_stage=KGStage.NORMALIZED,
|
||||
)
|
||||
|
||||
# Update transferred
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.id_name == entity.id_name
|
||||
).update({"transferred_id_name": new_entity.id_name})
|
||||
db_session.flush()
|
||||
|
||||
return new_entity
|
||||
|
||||
|
||||
def merge_entities(
|
||||
db_session: Session, parent: KGEntity, child: KGEntityExtractionStaging
|
||||
) -> KGEntity:
|
||||
"""Merge an entity from the extraction staging table into
|
||||
an existing entity in the normalized table.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
parent: Parent entity to merge into
|
||||
child: Child staging entity to merge
|
||||
|
||||
Returns:
|
||||
KGEntity: The merged entity
|
||||
"""
|
||||
# check we're not merging two entities with different document_ids
|
||||
if (
|
||||
parent.document_id is not None
|
||||
and child.document_id is not None
|
||||
and parent.document_id != child.document_id
|
||||
):
|
||||
raise ValueError(
|
||||
"Overwriting the document_id of an entity with a document_id already is not allowed"
|
||||
)
|
||||
|
||||
# update the parent entity (only document_id, alternative_names, occurrences)
|
||||
setting_doc = parent.document_id is None and child.document_id is not None
|
||||
document_id = child.document_id if setting_doc else parent.document_id
|
||||
alternative_names = set(parent.alternative_names or [])
|
||||
alternative_names.update(child.alternative_names or [])
|
||||
alternative_names.add(child.name.lower())
|
||||
alternative_names.discard(parent.name)
|
||||
|
||||
stmt = (
|
||||
update(KGEntity)
|
||||
.where(KGEntity.id_name == parent.id_name)
|
||||
.values(
|
||||
document_id=document_id,
|
||||
alternative_names=list(alternative_names),
|
||||
occurrences=parent.occurrences + child.occurrences,
|
||||
)
|
||||
.returning(KGEntity)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
if result is None:
|
||||
raise RuntimeError(f"Failed to merge entities with id_name: {parent.id_name}")
|
||||
|
||||
# Update the document's kg_stage if document_id is set
|
||||
if setting_doc and child.document_id is not None:
|
||||
dbdocument.update_document_kg_info(
|
||||
db_session,
|
||||
document_id=child.document_id,
|
||||
kg_stage=KGStage.NORMALIZED,
|
||||
)
|
||||
|
||||
# Update transferred
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.id_name == child.id_name
|
||||
).update({"transferred_id_name": parent.id_name})
|
||||
db_session.flush()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
|
||||
"""
|
||||
Check if a document_id exists in the kg_entities table and return the matching entity if found.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
document_id: The document ID to search for
|
||||
|
||||
Returns:
|
||||
The matching KGEntity if found, None otherwise
|
||||
"""
|
||||
query = select(KGEntity).where(KGEntity.document_id == document_id)
|
||||
result = db.execute(query).scalar()
|
||||
return result
|
||||
|
||||
|
||||
def get_grounded_entities_by_types(
|
||||
db_session: Session, entity_types: List[str], grounding: KGGroundingType
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities matching an entity_type.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity types to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.filter(KGEntityType.grounding == grounding)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_document_id_for_entity(db_session: Session, entity_id_name: str) -> str | None:
|
||||
"""Get the document ID associated with an entity.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity_id_name: The entity id_name to look up
|
||||
|
||||
Returns:
|
||||
The document ID if found, None otherwise
|
||||
"""
|
||||
entity = (
|
||||
db_session.query(KGEntity).filter(KGEntity.id_name == entity_id_name).first()
|
||||
)
|
||||
return entity.document_id if entity else None
|
||||
|
||||
|
||||
def delete_from_kg_entities_extraction_staging__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete entities from the extraction staging table."""
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.document_id.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def delete_from_kg_entities__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete entities from the normalized table."""
|
||||
db_session.query(KGEntity).filter(KGEntity.document_id.in_(document_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
|
||||
def get_entity_name(db_session: Session, entity_id_name: str) -> str | None:
|
||||
"""Get the name of an entity."""
|
||||
entity = (
|
||||
db_session.query(KGEntity).filter(KGEntity.id_name == entity_id_name).first()
|
||||
)
|
||||
return entity.name if entity else None
|
||||
257
backend/onyx/db/entity_type.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.kg.kg_default_entity_definitions import KGDefaultAccountEmployeeDefinitions
|
||||
from onyx.kg.kg_default_entity_definitions import (
|
||||
KGDefaultPrimaryGroundedEntityDefinitions,
|
||||
)
|
||||
from onyx.kg.models import KGGroundingType
|
||||
|
||||
|
||||
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null entity_values.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have entity_values defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.entity_values.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have grounding = GROUNDED.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounding = GROUNDED
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounded_source_name(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null grounded_source_name.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounded_source_name defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_type_by_grounded_source_name(
|
||||
db_session: Session, grounded_source_name: str
|
||||
) -> KGEntityType | None:
|
||||
"""Get an entity type by its grounded_source_name and return it as a dictionary.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
grounded_source_name: The grounded_source_name of the entity to retrieve
|
||||
|
||||
Returns:
|
||||
The KGEntityType if found, or None if no entity type matches
|
||||
"""
|
||||
entity_type = (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name == grounded_source_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if entity_type is None:
|
||||
return None
|
||||
|
||||
return entity_type
|
||||
|
||||
|
||||
def get_entity_types(
|
||||
db_session: Session,
|
||||
active: bool | None = True,
|
||||
) -> list[KGEntityType]:
|
||||
# Query the database for all distinct entity types
|
||||
|
||||
if active is None:
|
||||
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
|
||||
|
||||
else:
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.active == active)
|
||||
.order_by(KGEntityType.id_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def populate_default_primary_grounded_entity_type_information(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Populate the entity type information for the KG.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
"""
|
||||
|
||||
# get kg config information
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
|
||||
# Get all existing entity types
|
||||
existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}
|
||||
|
||||
# Create an instance of the default definitions
|
||||
default_definitions = KGDefaultPrimaryGroundedEntityDefinitions()
|
||||
|
||||
# Iterate over all attributes in the default definitions
|
||||
for id_name, definition in default_definitions.model_dump().items():
|
||||
# Skip if this entity type already exists
|
||||
if id_name in existing_entity_types:
|
||||
continue
|
||||
|
||||
# Create new entity type
|
||||
|
||||
description = definition["description"].replace(
|
||||
"---vendor_name---", kg_config_settings.KG_VENDOR
|
||||
)
|
||||
|
||||
new_entity_type = KGEntityType(
|
||||
id_name=id_name,
|
||||
description=description,
|
||||
grounding=definition["grounding"],
|
||||
grounded_source_name=definition["grounded_source_name"],
|
||||
active=False,
|
||||
)
|
||||
|
||||
# Add to session
|
||||
db_session.add(new_entity_type)
|
||||
|
||||
# Commit changes
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def populate_default_employee_account_information(db_session: Session) -> None:
|
||||
"""Populate the entity type information for the KG.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
"""
|
||||
|
||||
# get kg config information
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
|
||||
# Get all existing entity types
|
||||
existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}
|
||||
|
||||
# Create an instance of the default definitions
|
||||
default_definitions = KGDefaultAccountEmployeeDefinitions()
|
||||
|
||||
# Iterate over all attributes in the default definitions
|
||||
for id_name, definition in default_definitions.model_dump().items():
|
||||
# Skip if this entity type already exists
|
||||
if id_name in existing_entity_types:
|
||||
continue
|
||||
|
||||
# Create new entity type
|
||||
description = definition["description"].replace(
|
||||
"---vendor_name---", kg_config_settings.KG_VENDOR
|
||||
)
|
||||
new_entity_type = KGEntityType(
|
||||
id_name=id_name,
|
||||
description=description,
|
||||
grounding=definition["grounding"],
|
||||
grounded_source_name=definition["grounded_source_name"],
|
||||
active=definition["active"],
|
||||
)
|
||||
|
||||
# Add to session
|
||||
db_session.add(new_entity_type)
|
||||
|
||||
# Commit changes
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_grounded_entity_types_with_null_grounded_source(
    db_session: Session,
) -> List[KGEntityType]:
    """Get all entity types that have null grounded_source_name and grounding = GROUNDED.

    Args:
        db_session: SQLAlchemy session

    Returns:
        List of KGEntityType objects that have null grounded_source_name and grounding = GROUNDED
    """
    return (
        db_session.query(KGEntityType)
        .filter(KGEntityType.grounded_source_name.is_(None))
        .filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
        .all()
    )


def get_entity_types_by_grounding(
    db_session: Session,
    grounding: KGGroundingType,
) -> List[KGEntityType]:
    """Get all entity types that have a specific grounding.

    Args:
        db_session: SQLAlchemy session
        grounding: The grounding type to filter by

    Returns:
        List of KGEntityType objects that have the specified grounding
    """
    return (
        db_session.query(KGEntityType).filter(KGEntityType.grounding == grounding).all()
    )


def get_grounded_source_name(db_session: Session, entity_type: str) -> str | None:
    """
    Get the grounded source name for an entity type.
    """
    result = (
        db_session.query(KGEntityType)
        .filter(KGEntityType.id_name == entity_type)
        .first()
    )
    if result is None:
        return None

    return result.grounded_source_name

141 backend/onyx/db/kg_config.py Normal file
@@ -0,0 +1,141 @@
from datetime import datetime
from enum import Enum

from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from onyx.db.models import KGConfig
from onyx.kg.models import KGConfigSettings
from onyx.kg.models import KGConfigVars


class KGProcessingType(Enum):
    EXTRACTION = "extraction"
    CLUSTERING = "clustering"


def get_kg_enablement(db_session: Session) -> bool:
    check = (
        db_session.query(KGConfig.kg_variable_values)
        .filter(
            KGConfig.kg_variable_name == "KG_ENABLED",
            KGConfig.kg_variable_values == ["true"],
        )
        .first()
    )
    return check is not None


def get_kg_config_settings(db_session: Session) -> KGConfigSettings:
|
||||
results = db_session.query(KGConfig).all()
|
||||
|
||||
kg_config_settings = KGConfigSettings()
|
||||
for result in results:
|
||||
if result.kg_variable_name == KGConfigVars.KG_ENABLED:
|
||||
kg_config_settings.KG_ENABLED = result.kg_variable_values[0] == "true"
|
||||
elif result.kg_variable_name == KGConfigVars.KG_VENDOR:
|
||||
if len(result.kg_variable_values) > 0:
|
||||
kg_config_settings.KG_VENDOR = result.kg_variable_values[0]
|
||||
else:
|
||||
kg_config_settings.KG_VENDOR = None
|
||||
elif result.kg_variable_name == KGConfigVars.KG_VENDOR_DOMAINS:
|
||||
kg_config_settings.KG_VENDOR_DOMAINS = result.kg_variable_values
|
||||
elif result.kg_variable_name == KGConfigVars.KG_IGNORE_EMAIL_DOMAINS:
|
||||
kg_config_settings.KG_IGNORE_EMAIL_DOMAINS = result.kg_variable_values
|
||||
elif result.kg_variable_name == KGConfigVars.KG_COVERAGE_START:
|
||||
kg_coverage_start_str = result.kg_variable_values[0] or "1970-01-01"
|
||||
|
||||
kg_config_settings.KG_COVERAGE_START = datetime.strptime(
|
||||
kg_coverage_start_str, "%Y-%m-%d"
|
||||
)
|
||||
|
||||
elif result.kg_variable_name == KGConfigVars.KG_MAX_COVERAGE_DAYS:
|
||||
kg_max_coverage_days_str = result.kg_variable_values[0]
|
||||
if not kg_max_coverage_days_str.isdigit():
|
||||
raise ValueError(
|
||||
f"KG_MAX_COVERAGE_DAYS is not a number: {kg_max_coverage_days_str}"
|
||||
)
|
||||
kg_config_settings.KG_MAX_COVERAGE_DAYS = max(
|
||||
0, int(kg_max_coverage_days_str)
|
||||
)
|
||||
|
||||
elif result.kg_variable_name == KGConfigVars.KG_EXTRACTION_IN_PROGRESS:
|
||||
kg_config_settings.KG_EXTRACTION_IN_PROGRESS = (
|
||||
result.kg_variable_values[0] == "true"
|
||||
)
|
||||
elif result.kg_variable_name == KGConfigVars.KG_CLUSTERING_IN_PROGRESS:
|
||||
kg_config_settings.KG_CLUSTERING_IN_PROGRESS = (
|
||||
result.kg_variable_values[0] == "true"
|
||||
)
|
||||
elif result.kg_variable_name == KGConfigVars.KG_MAX_PARENT_RECURSION_DEPTH:
|
||||
kg_max_parent_recursion_depth_str = result.kg_variable_values[0]
|
||||
if not kg_max_parent_recursion_depth_str.isdigit():
|
||||
raise ValueError(
|
||||
f"KG_MAX_PARENT_RECURSION_DEPTH is not a number: {kg_max_parent_recursion_depth_str}"
|
||||
)
|
||||
kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH = max(
|
||||
0, int(kg_max_parent_recursion_depth_str)
|
||||
)
|
||||
elif result.kg_variable_name == KGConfigVars.KG_EXPOSED:
|
||||
kg_config_settings.KG_EXPOSED = result.kg_variable_values[0] == "true"
|
||||
|
||||
return kg_config_settings
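
# --- Illustrative sketch (not part of the change set above) ---
# Each setting is stored as its own kg_config row holding a list of strings; the
# parser above turns those rows back into a typed KGConfigSettings object. A
# minimal seed for a fresh installation could look like the function below. The
# plain string variable names mirror the "KG_ENABLED" comparison used earlier in
# this file; treat this as an assumption, not the project's official seeding path.
def seed_minimal_kg_config(db_session: Session, vendor: str, domains: list[str]) -> None:
    db_session.add(KGConfig(kg_variable_name="KG_ENABLED", kg_variable_values=["true"]))
    db_session.add(KGConfig(kg_variable_name="KG_VENDOR", kg_variable_values=[vendor]))
    db_session.add(
        KGConfig(kg_variable_name="KG_VENDOR_DOMAINS", kg_variable_values=domains)
    )
    db_session.flush()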
|
||||
|
||||
|
||||
def set_kg_processing_in_progress_status(
    db_session: Session, processing_type: KGProcessingType, in_progress: bool
) -> None:
    """
    Set the KG_EXTRACTION_IN_PROGRESS or KG_CLUSTERING_IN_PROGRESS configuration value.

    Args:
        db_session: The database session to use
        processing_type: Which flag to set (extraction or clustering)
        in_progress: Whether KG processing is in progress (True) or not (False)
    """
|
||||
# Convert boolean to string and wrap in list as required by the model
|
||||
value = [str(in_progress).lower()]
|
||||
kg_variable_name = "KG_EXTRACTION_IN_PROGRESS" # Default value
|
||||
|
||||
if processing_type == KGProcessingType.CLUSTERING:
|
||||
kg_variable_name = "KG_CLUSTERING_IN_PROGRESS"
|
||||
|
||||
# Use PostgreSQL's upsert functionality
|
||||
stmt = (
|
||||
pg_insert(KGConfig)
|
||||
.values(kg_variable_name=str(kg_variable_name), kg_variable_values=value)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["kg_variable_name"], set_=dict(kg_variable_values=value)
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def get_kg_processing_in_progress_status(
|
||||
db_session: Session, processing_type: KGProcessingType
|
||||
) -> bool:
|
||||
"""
|
||||
Get the current KG_EXTRACTION_IN_PROGRESS or KG_CLUSTERING_IN_PROGRESS configuration value.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
|
||||
Returns:
|
||||
bool: True if KG processing is in progress, False otherwise
|
||||
"""
|
||||
|
||||
kg_variable_name = "KG_EXTRACTION_IN_PROGRESS" # Default value
|
||||
if processing_type == KGProcessingType.CLUSTERING:
|
||||
kg_variable_name = "KG_CLUSTERING_IN_PROGRESS"
|
||||
|
||||
config = (
|
||||
db_session.query(KGConfig)
|
||||
.filter(KGConfig.kg_variable_name == kg_variable_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
return config.kg_variable_values[0] == "true"
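
# --- Illustrative sketch (not part of the change set above) ---
# How a background job might guard a clustering pass with the in-progress flags
# above. The clustering callable is passed in because no clustering entry point is
# defined in this file; treat the whole function as an assumption-laden example.
from collections.abc import Callable


def guarded_clustering_run(
    db_session: Session, run_clustering: Callable[[Session], None]
) -> bool:
    """Return True if a clustering pass ran, False if one was already in progress."""
    if get_kg_processing_in_progress_status(db_session, KGProcessingType.CLUSTERING):
        return False
    # set_kg_processing_in_progress_status only executes the upsert;
    # committing is left to the caller.
    set_kg_processing_in_progress_status(db_session, KGProcessingType.CLUSTERING, True)
    db_session.commit()
    try:
        run_clustering(db_session)
    finally:
        set_kg_processing_in_progress_status(
            db_session, KGProcessingType.CLUSTERING, False
        )
        db_session.commit()
    return True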
167 backend/onyx/db/kg_temp_view.py Normal file
@@ -0,0 +1,167 @@
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_user_view_names(user_email: str) -> tuple[str, str]:
|
||||
user_email_cleaned = user_email.replace("@", "_").replace(".", "_")
|
||||
return (
|
||||
f"allowed_docs_{user_email_cleaned}",
|
||||
f"kg_relationships_with_access_{user_email_cleaned}",
|
||||
)
|
||||
|
||||
|
||||
# First, create the view definition
|
||||
def create_views(
|
||||
db_session: Session,
|
||||
user_email: str,
|
||||
allowed_docs_view_name: str = "allowed_docs",
|
||||
kg_relationships_view_name: str = "kg_relationships_with_access",
|
||||
) -> None:
|
||||
# Create ALLOWED_DOCS view
|
||||
allowed_docs_view = text(
|
||||
f"""
|
||||
CREATE OR REPLACE VIEW {allowed_docs_view_name} AS
|
||||
WITH kg_used_docs AS (
|
||||
SELECT document_id as kg_used_doc_id
|
||||
FROM kg_entity d
|
||||
WHERE document_id IS NOT NULL
|
||||
),
|
||||
|
||||
public_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document d
|
||||
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
|
||||
WHERE d.is_public
|
||||
),
|
||||
user_owned_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document_by_connector_credential_pair d
|
||||
JOIN credential c ON d.credential_id = c.id
|
||||
JOIN connector_credential_pair ccp ON
|
||||
d.connector_id = ccp.connector_id AND
|
||||
d.credential_id = ccp.credential_id
|
||||
JOIN "user" u ON c.user_id = u.id
|
||||
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
|
||||
WHERE ccp.status != 'DELETING'
|
||||
AND ccp.access_type != 'SYNC'
|
||||
AND u.email = :user_email
|
||||
),
|
||||
user_group_accessible_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document_by_connector_credential_pair d
|
||||
JOIN connector_credential_pair ccp ON
|
||||
d.connector_id = ccp.connector_id AND
|
||||
d.credential_id = ccp.credential_id
|
||||
JOIN user_group__connector_credential_pair ugccp ON
|
||||
ccp.id = ugccp.cc_pair_id
|
||||
JOIN user__user_group uug ON
|
||||
uug.user_group_id = ugccp.user_group_id
|
||||
JOIN "user" u ON uug.user_id = u.id
|
||||
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
|
||||
WHERE kud.kg_used_doc_id IS NOT NULL
|
||||
AND ccp.status != 'DELETING'
|
||||
AND ccp.access_type != 'SYNC'
|
||||
AND u.email = :user_email
|
||||
),
|
||||
external_user_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document d
|
||||
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
|
||||
WHERE kud.kg_used_doc_id IS NOT NULL
|
||||
AND :user_email = ANY(external_user_emails)
|
||||
),
|
||||
external_group_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document d
|
||||
INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
|
||||
JOIN user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
|
||||
JOIN "user" u ON ueg.user_id = u.id
|
||||
WHERE kud.kg_used_doc_id IS NOT NULL
|
||||
AND u.email = :user_email
|
||||
)
|
||||
SELECT DISTINCT allowed_doc_id FROM (
|
||||
SELECT allowed_doc_id FROM public_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM user_owned_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM user_group_accessible_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM external_user_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM external_group_docs
|
||||
) combined_docs
|
||||
"""
|
||||
).bindparams(user_email=user_email)
|
||||
|
||||
# Create the main view that uses ALLOWED_DOCS
|
||||
kg_relationships_view = text(
|
||||
f"""
|
||||
CREATE OR REPLACE VIEW {kg_relationships_view_name} AS
|
||||
SELECT kgr.id_name as relationship,
|
||||
kgr.source_node as source_entity,
|
||||
kgr.target_node as target_entity,
|
||||
kgr.source_node_type as source_entity_type,
|
||||
kgr.target_node_type as target_entity_type,
|
||||
kgr.type as relationship_description,
|
||||
kgr.relationship_type_id_name as relationship_type,
|
||||
kgr.source_document as source_document,
|
||||
d.doc_updated_at as source_date,
|
||||
se.attributes as source_entity_attributes,
|
||||
te.attributes as target_entity_attributes
|
||||
FROM kg_relationship kgr
|
||||
INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document
|
||||
JOIN document d on d.id = kgr.source_document
|
||||
JOIN kg_entity se on se.id_name = kgr.source_node
|
||||
JOIN kg_entity te on te.id_name = kgr.target_node
|
||||
"""
|
||||
)
|
||||
|
||||
# Execute the views using the session
|
||||
db_session.execute(allowed_docs_view)
|
||||
db_session.execute(kg_relationships_view)
|
||||
|
||||
# Grant permissions on view to readonly user
|
||||
|
||||
db_session.execute(
|
||||
text(f"GRANT SELECT ON {kg_relationships_view_name} TO {DB_READONLY_USER}")
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def drop_views(
|
||||
db_session: Session,
|
||||
allowed_docs_view_name: str = "allowed_docs",
|
||||
kg_relationships_view_name: str = "kg_relationships_with_access",
|
||||
) -> None:
|
||||
"""
|
||||
Drops the temporary views created by create_views.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
allowed_docs_view_name: Name of the allowed_docs view
|
||||
kg_relationships_view_name: Name of the kg_relationships view
|
||||
"""
|
||||
# First revoke access from the readonly user
|
||||
revoke_kg_relationships = text(
|
||||
f"REVOKE SELECT ON {kg_relationships_view_name} FROM {DB_READONLY_USER}"
|
||||
)
|
||||
|
||||
db_session.execute(revoke_kg_relationships)
|
||||
|
||||
# Drop the views in reverse order of creation to handle dependencies
|
||||
drop_kg_relationships = text(f"DROP VIEW IF EXISTS {kg_relationships_view_name}")
|
||||
drop_allowed_docs = text(f"DROP VIEW IF EXISTS {allowed_docs_view_name}")
|
||||
|
||||
db_session.execute(drop_kg_relationships)
|
||||
db_session.execute(drop_allowed_docs)
|
||||
db_session.commit()
|
||||
return None
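
# --- Illustrative sketch (not part of the change set above) ---
# Ties the helpers in this file together: build the per-user view names, create the
# views, let the permission-less read-only role query the relationship view, then
# drop the views again. The separate `readonly_session` (a session opened as
# DB_READONLY_USER) is an assumption; only the helpers above come from this change.
def run_kg_query_for_user(
    db_session: Session, readonly_session: Session, user_email: str
) -> list:
    allowed_docs_view, kg_relationships_view = get_user_view_names(user_email)
    create_views(
        db_session,
        user_email=user_email,
        allowed_docs_view_name=allowed_docs_view,
        kg_relationships_view_name=kg_relationships_view,
    )
    try:
        # The read-only role only ever gets SELECT on this ad hoc view.
        rows = readonly_session.execute(
            text(
                f"SELECT relationship, source_entity, target_entity "
                f"FROM {kg_relationships_view}"
            )
        ).fetchall()
    finally:
        drop_views(
            db_session,
            allowed_docs_view_name=allowed_docs_view,
            kg_relationships_view_name=kg_relationships_view,
        )
    return rows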
|
||||
@@ -39,6 +39,7 @@ from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
@@ -69,6 +70,7 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
@@ -586,6 +588,17 @@ class Document(Base):
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# tables for the knowledge graph data
|
||||
kg_stage: Mapped[KGStage] = mapped_column(
|
||||
Enum(KGStage, native_enum=False),
|
||||
comment="Status of knowledge graph extraction for this document",
|
||||
index=True,
|
||||
)
|
||||
|
||||
kg_processing_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
@@ -604,12 +617,639 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class KGConfig(Base):
|
||||
__tablename__ = "kg_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
|
||||
kg_variable_name: Mapped[str] = mapped_column(NullFilteredString, nullable=False)
|
||||
kg_variable_values: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
__tablename__ = "kg_entity_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
grounding: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
    attributes: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=True,
        default=dict,
        server_default="{}",
        comment="Filtering based on document attribute",
    )
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deep_extraction: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
entity_values: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True, default=None
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this entity type",
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipType(Base):
|
||||
__tablename__ = "kg_relationship_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type",
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipTypeExtractionStaging(Base):
|
||||
__tablename__ = "kg_relationship_type_extraction_staging"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
transferred: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type_staging",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type_staging",
|
||||
)
|
||||
|
||||
|
||||
class KGEntity(Base):
|
||||
__tablename__ = "kg_entity"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
entity_class: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
entity_key: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
entity_subtype: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
name_trigrams: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String(3)),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
attributes: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Attributes for this entity",
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGEntityExtractionStaging(Base):
|
||||
__tablename__ = "kg_entity_extraction_staging"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
attributes: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Attributes for this entity",
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType", backref="entity_staging"
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
transferred_id_name: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
    # Basic entity information
    entity_class: Mapped[str] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )

    entity_key: Mapped[str] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )

    entity_subtype: Mapped[str] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )

    parent_key: Mapped[str] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationship(Base):
|
||||
__tablename__ = "kg_relationship"
|
||||
|
||||
# Primary identifier - now part of composite key
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
source_document: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Source and target nodes (foreign keys to Entity table)
|
||||
source_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
target_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
source_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship type
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
# Add new relationship type reference
|
||||
relationship_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_relationship_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Add the SQLAlchemy relationship property
|
||||
relationship_type: Mapped["KGRelationshipType"] = relationship(
|
||||
"KGRelationshipType", backref="relationship"
|
||||
)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to Entity table
|
||||
source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
|
||||
target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
|
||||
document: Mapped["Document"] = relationship(
|
||||
"Document", foreign_keys=[source_document]
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Composite primary key
|
||||
PrimaryKeyConstraint("id_name", "source_document"),
|
||||
# Index for querying relationships by type
|
||||
Index("ix_kg_relationship_type", type),
|
||||
# Composite index for source/target queries
|
||||
Index("ix_kg_relationship_nodes", source_node, target_node),
|
||||
# Ensure unique relationships between nodes of a specific type
|
||||
UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipExtractionStaging(Base):
|
||||
__tablename__ = "kg_relationship_extraction_staging"
|
||||
|
||||
# Primary identifier - now part of composite key
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
source_document: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Source and target nodes (foreign keys to Entity table)
|
||||
source_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
source_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship type
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
# Add new relationship type reference
|
||||
relationship_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_relationship_type_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Add the SQLAlchemy relationship property
|
||||
relationship_type: Mapped["KGRelationshipTypeExtractionStaging"] = relationship(
|
||||
"KGRelationshipTypeExtractionStaging", backref="relationship_staging"
|
||||
)
|
||||
|
||||
occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
|
||||
transferred: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to Entity table
|
||||
source: Mapped["KGEntityExtractionStaging"] = relationship(
|
||||
"KGEntityExtractionStaging", foreign_keys=[source_node]
|
||||
)
|
||||
target: Mapped["KGEntityExtractionStaging"] = relationship(
|
||||
"KGEntityExtractionStaging", foreign_keys=[target_node]
|
||||
)
|
||||
document: Mapped["Document"] = relationship(
|
||||
"Document", foreign_keys=[source_document]
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Composite primary key
|
||||
PrimaryKeyConstraint("id_name", "source_document"),
|
||||
# Index for querying relationships by type
|
||||
Index("ix_kg_relationship_type", type),
|
||||
# Composite index for source/target queries
|
||||
Index("ix_kg_relationship_nodes", source_node, target_node),
|
||||
# Ensure unique relationships between nodes of a specific type
|
||||
UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KGTerm(Base):
|
||||
__tablename__ = "kg_term"
|
||||
|
||||
# Make id_term the primary key
|
||||
id_term: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
# List of entity types this term applies to
|
||||
entity_types: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Index for searching terms with specific entity types
|
||||
Index("ix_search_term_entities", entity_types),
|
||||
# Index for term lookups
|
||||
Index("ix_search_term_term", id_term),
|
||||
)
|
||||
|
||||
|
||||
class ChunkStats(Base):
    __tablename__ = "chunk_stats"
    # NOTE: if more sensitive data is added here for display, make sure to add user/group permission

    # this should correspond to the ID of the document
    # (as is passed around in Onyx)
    id: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
@@ -691,6 +1331,16 @@ class Connector(Base):
|
||||
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
|
||||
kg_processing_enabled: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this connector should extract knowledge graph entities",
|
||||
)
|
||||
|
||||
kg_coverage_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
|
||||
590 backend/onyx/db/relationships.py Normal file
@@ -0,0 +1,590 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import onyx.db.document as dbdocument
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.db.models import KGStage
|
||||
from onyx.kg.utils.formatting_utils import extract_relationship_type_id
|
||||
from onyx.kg.utils.formatting_utils import format_relationship_id
|
||||
from onyx.kg.utils.formatting_utils import get_entity_type
|
||||
from onyx.kg.utils.formatting_utils import make_relationship_id
|
||||
from onyx.kg.utils.formatting_utils import make_relationship_type_id
|
||||
from onyx.kg.utils.formatting_utils import split_relationship_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def upsert_staging_relationship(
|
||||
db_session: Session,
|
||||
relationship_id_name: str,
|
||||
source_document_id: str,
|
||||
occurrences: int = 1,
|
||||
) -> KGRelationshipExtractionStaging:
|
||||
"""
|
||||
Add or update a new staging relationship to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
|
||||
source_document_id: ID of the source document
|
||||
occurrences: Number of times this relationship has been found
|
||||
Returns:
|
||||
The created or updated KGRelationshipExtractionStaging object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
|
||||
"""
|
||||
# Generate a unique ID for the relationship
|
||||
relationship_id_name = format_relationship_id(relationship_id_name)
|
||||
(
|
||||
source_entity_id_name,
|
||||
relationship_string,
|
||||
target_entity_id_name,
|
||||
) = split_relationship_id(relationship_id_name)
|
||||
|
||||
source_entity_type = get_entity_type(source_entity_id_name)
|
||||
target_entity_type = get_entity_type(target_entity_id_name)
|
||||
relationship_type = extract_relationship_type_id(relationship_id_name)
|
||||
|
||||
# Insert the new relationship
|
||||
stmt = (
|
||||
postgresql.insert(KGRelationshipExtractionStaging)
|
||||
.values(
|
||||
{
|
||||
"id_name": relationship_id_name,
|
||||
"source_node": source_entity_id_name,
|
||||
"target_node": target_entity_id_name,
|
||||
"source_node_type": source_entity_type,
|
||||
"target_node_type": target_entity_type,
|
||||
"type": relationship_string.lower(),
|
||||
"relationship_type_id_name": relationship_type,
|
||||
"source_document": source_document_id,
|
||||
"occurrences": occurrences,
|
||||
}
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name", "source_document"],
|
||||
set_=dict(
|
||||
occurrences=KGRelationshipExtractionStaging.occurrences + occurrences,
|
||||
),
|
||||
)
|
||||
.returning(KGRelationshipExtractionStaging)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
if result is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to create or increment staging relationship with id_name: {relationship_id_name}"
|
||||
)
|
||||
|
||||
# Update the document's kg_stage if source_document is provided
|
||||
if source_document_id is not None:
|
||||
dbdocument.update_document_kg_info(
|
||||
db_session,
|
||||
document_id=source_document_id,
|
||||
kg_stage=KGStage.EXTRACTED,
|
||||
)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return result
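
# --- Illustrative sketch (not part of the change set above) ---
# The id_name passed to upsert_staging_relationship is expected to look like
# "source__relationship__target", where source and target are entity id_names.
# The "TYPE:key" entity form below is an assumed example, not defined in this change.
def stage_example_relationship(db_session: Session) -> KGRelationshipExtractionStaging:
    return upsert_staging_relationship(
        db_session,
        relationship_id_name="EMPLOYEE:jane__works_with__ACCOUNT:acme",
        source_document_id="doc-123",
        occurrences=1,
    )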
|
||||
|
||||
|
||||
def transfer_relationship(
|
||||
db_session: Session,
|
||||
relationship: KGRelationshipExtractionStaging,
|
||||
entity_translations: dict[str, str],
|
||||
) -> KGRelationship:
|
||||
"""
|
||||
Transfer a relationship from the staging table to the normalized table.
|
||||
"""
|
||||
# Translate the source and target nodes
|
||||
source_node = entity_translations.get(
|
||||
relationship.source_node, relationship.source_node
|
||||
)
|
||||
target_node = entity_translations.get(
|
||||
relationship.target_node, relationship.target_node
|
||||
)
|
||||
relationship_id_name = make_relationship_id(
|
||||
source_node, relationship.type, target_node
|
||||
)
|
||||
|
||||
# Create the transferred relationship
|
||||
stmt = (
|
||||
pg_insert(KGRelationship)
|
||||
.values(
|
||||
id_name=relationship_id_name,
|
||||
source_node=source_node,
|
||||
target_node=target_node,
|
||||
source_node_type=relationship.source_node_type,
|
||||
target_node_type=relationship.target_node_type,
|
||||
type=relationship.type,
|
||||
relationship_type_id_name=relationship.relationship_type_id_name,
|
||||
source_document=relationship.source_document,
|
||||
occurrences=relationship.occurrences,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name", "source_document"],
|
||||
set_=dict(
|
||||
occurrences=KGRelationship.occurrences + relationship.occurrences,
|
||||
),
|
||||
)
|
||||
.returning(KGRelationship)
|
||||
)
|
||||
|
||||
new_relationship = db_session.execute(stmt).scalar()
|
||||
if new_relationship is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to transfer relationship with id_name: {relationship.id_name}"
|
||||
)
|
||||
|
||||
# Update transferred
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.id_name == relationship.id_name,
|
||||
KGRelationshipExtractionStaging.source_document == relationship.source_document,
|
||||
).update({"transferred": True})
|
||||
db_session.flush()
|
||||
|
||||
return new_relationship
|
||||
|
||||
|
||||
def upsert_staging_relationship_type(
|
||||
db_session: Session,
|
||||
source_entity_type: str,
|
||||
relationship_type: str,
|
||||
target_entity_type: str,
|
||||
definition: bool = False,
|
||||
extraction_count: int = 1,
|
||||
) -> KGRelationshipTypeExtractionStaging:
|
||||
"""
|
||||
Add a new relationship type to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
source_entity_type: Type of the source entity
|
||||
relationship_type: Type of relationship
|
||||
target_entity_type: Type of the target entity
|
||||
definition: Whether this relationship type represents a definition (default False)
|
||||
|
||||
Returns:
|
||||
The created KGRelationshipTypeExtractionStaging object
|
||||
"""
|
||||
|
||||
id_name = make_relationship_type_id(
|
||||
source_entity_type, relationship_type, target_entity_type
|
||||
)
|
||||
|
||||
# Create new relationship type
|
||||
stmt = (
|
||||
postgresql.insert(KGRelationshipTypeExtractionStaging)
|
||||
.values(
|
||||
{
|
||||
"id_name": id_name,
|
||||
"name": relationship_type,
|
||||
"source_entity_type_id_name": source_entity_type.upper(),
|
||||
"target_entity_type_id_name": target_entity_type.upper(),
|
||||
"definition": definition,
|
||||
"occurrences": extraction_count,
|
||||
"type": relationship_type, # Using the relationship_type as the type
|
||||
"active": True, # Setting as active by default
|
||||
}
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
occurrences=KGRelationshipTypeExtractionStaging.occurrences
|
||||
+ extraction_count,
|
||||
),
|
||||
)
|
||||
.returning(KGRelationshipTypeExtractionStaging)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
if result is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to create or increment staging relationship type with id_name: {id_name}"
|
||||
)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def transfer_relationship_type(
|
||||
db_session: Session,
|
||||
relationship_type: KGRelationshipTypeExtractionStaging,
|
||||
) -> KGRelationshipType:
|
||||
"""
|
||||
Transfer a relationship type from the staging table to the normalized table.
|
||||
"""
|
||||
stmt = (
|
||||
pg_insert(KGRelationshipType)
|
||||
.values(
|
||||
id_name=relationship_type.id_name,
|
||||
name=relationship_type.name,
|
||||
source_entity_type_id_name=relationship_type.source_entity_type_id_name,
|
||||
target_entity_type_id_name=relationship_type.target_entity_type_id_name,
|
||||
definition=relationship_type.definition,
|
||||
occurrences=relationship_type.occurrences,
|
||||
type=relationship_type.type,
|
||||
active=relationship_type.active,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
occurrences=KGRelationshipType.occurrences
|
||||
+ relationship_type.occurrences,
|
||||
),
|
||||
)
|
||||
.returning(KGRelationshipType)
|
||||
)
|
||||
|
||||
new_relationship_type = db_session.execute(stmt).scalar()
|
||||
if new_relationship_type is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to transfer relationship type with id_name: {relationship_type.id_name}"
|
||||
)
|
||||
|
||||
# Update transferred
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).filter(
|
||||
KGRelationshipTypeExtractionStaging.id_name == relationship_type.id_name
|
||||
).update({"transferred": True})
|
||||
db_session.flush()
|
||||
|
||||
return new_relationship_type
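
# --- Illustrative sketch (not part of the change set above) ---
# The staging -> normalized flow implied by the two transfer helpers above:
# relationship types are moved first, then relationships, remapping staged entity
# id_names to their normalized ids. Building `entity_translations` happens
# elsewhere in the pipeline; here it is simply passed in as an assumption.
def transfer_all_staged_relationships(
    db_session: Session, entity_translations: dict[str, str]
) -> None:
    for rel_type in (
        db_session.query(KGRelationshipTypeExtractionStaging)
        .filter(KGRelationshipTypeExtractionStaging.transferred.is_(False))
        .all()
    ):
        transfer_relationship_type(db_session, rel_type)
    for rel in (
        db_session.query(KGRelationshipExtractionStaging)
        .filter(KGRelationshipExtractionStaging.transferred.is_(False))
        .all()
    ):
        transfer_relationship(db_session, rel, entity_translations)
    db_session.commit()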
|
||||
|
||||
|
||||
def get_parent_child_relationships_and_types(
|
||||
db_session: Session,
|
||||
depth: int,
|
||||
) -> tuple[
|
||||
list[KGRelationshipExtractionStaging], list[KGRelationshipTypeExtractionStaging]
|
||||
]:
|
||||
"""
|
||||
Create parent-child relationships and relationship types from staging entities with
|
||||
a parent key, if the parent exists in the normalized entities table. Will create
|
||||
relationships up to depth levels. E.g., if depth is 2, a relationship will be created
|
||||
between the entity and its parent, and the entity and its grandparents (if any).
|
||||
A relationship will not be created if the parent does not exist.
|
||||
"""
|
||||
relationship_types: dict[str, KGRelationshipTypeExtractionStaging] = {}
|
||||
relationships: dict[tuple[str, str | None], KGRelationshipExtractionStaging] = {}
|
||||
|
||||
parented_entities = (
|
||||
db_session.query(KGEntityExtractionStaging)
|
||||
.filter(KGEntityExtractionStaging.parent_key.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
# create has_subcomponent relationships and relationship types
|
||||
for entity in parented_entities:
|
||||
child = entity
|
||||
if entity.transferred_id_name is None:
|
||||
logger.warning(f"Entity {entity.id_name} has not yet been transferred")
|
||||
continue
|
||||
|
||||
for i in range(depth):
|
||||
if not child.parent_key:
|
||||
break
|
||||
|
||||
# find the transferred parent entity
|
||||
parent = (
|
||||
db_session.query(KGEntity)
|
||||
.filter(
|
||||
KGEntity.entity_class == child.entity_class,
|
||||
KGEntity.entity_key == child.parent_key,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if parent is None:
|
||||
logger.warning(f"Parent entity not found for {entity.id_name}")
|
||||
break
|
||||
|
||||
# create the relationship type
|
||||
relationship_type = upsert_staging_relationship_type(
|
||||
db_session=db_session,
|
||||
source_entity_type=parent.entity_type_id_name,
|
||||
relationship_type="has_subcomponent",
|
||||
target_entity_type=entity.entity_type_id_name,
|
||||
definition=False,
|
||||
extraction_count=1,
|
||||
)
|
||||
relationship_types[relationship_type.id_name] = relationship_type
|
||||
|
||||
# create the relationship
|
||||
# (don't add it to the table as we're using the transferred id, which breaks fk constraints)
|
||||
relationship_id_name = make_relationship_id(
|
||||
parent.id_name, "has_subcomponent", entity.transferred_id_name
|
||||
)
|
||||
if (parent.id_name, entity.document_id) not in relationships:
|
||||
(
|
||||
source_entity_id_name,
|
||||
relationship_string,
|
||||
target_entity_id_name,
|
||||
) = split_relationship_id(relationship_id_name)
|
||||
|
||||
source_entity_type = get_entity_type(source_entity_id_name)
|
||||
target_entity_type = get_entity_type(target_entity_id_name)
|
||||
relationship_type_id_name = extract_relationship_type_id(
|
||||
relationship_id_name
|
||||
)
|
||||
relationships[(relationship_id_name, entity.document_id)] = (
|
||||
KGRelationshipExtractionStaging(
|
||||
id_name=relationship_id_name,
|
||||
source_node=source_entity_id_name,
|
||||
target_node=target_entity_id_name,
|
||||
source_node_type=source_entity_type,
|
||||
target_node_type=target_entity_type,
|
||||
type=relationship_string,
|
||||
relationship_type_id_name=relationship_type_id_name,
|
||||
source_document=entity.document_id,
|
||||
occurrences=1,
|
||||
)
|
||||
)
|
||||
else:
|
||||
relationships[(parent.id_name, entity.document_id)].occurrences += 1
|
||||
|
||||
# set parent as the next child (unless we're at the max depth)
|
||||
if i < depth - 1:
|
||||
parent_staging = (
|
||||
db_session.query(KGEntityExtractionStaging)
|
||||
.filter(
|
||||
KGEntityExtractionStaging.transferred_id_name == parent.id_name
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if parent_staging is None:
|
||||
break
|
||||
child = parent_staging
|
||||
|
||||
return list(relationships.values()), list(relationship_types.values())
|
||||
|
||||
|
||||
def delete_relationships_by_id_names(
|
||||
db_session: Session, id_names: list[str], kg_stage: KGStage
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationships from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationships deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
|
||||
deleted_count = 0
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipExtractionStaging)
|
||||
.filter(KGRelationshipExtractionStaging.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def delete_relationship_types_by_id_names(
|
||||
db_session: Session, id_names: list[str], kg_stage: KGStage
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationship types from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship type id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationship types deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = 0
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipTypeExtractionStaging)
|
||||
.filter(KGRelationshipTypeExtractionStaging.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipType)
|
||||
.filter(KGRelationshipType.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_relationships_for_entity_type_pairs(
|
||||
db_session: Session, entity_type_pairs: list[tuple[str, str]]
|
||||
) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Get relationship types from the database based on a list of entity type pairs.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects where source and target types match the provided pairs
|
||||
"""
|
||||
|
||||
conditions = [
|
||||
(
|
||||
(KGRelationshipType.source_entity_type_id_name == source_type)
|
||||
& (KGRelationshipType.target_entity_type_id_name == target_type)
|
||||
)
|
||||
for source_type, target_type in entity_type_pairs
|
||||
]
|
||||
|
||||
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
|
||||
|
||||
|
||||
def get_allowed_relationship_type_pairs(
|
||||
db_session: Session, entities: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the allowed relationship pairs for the given entities.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entities: List of entity type ID names to filter by
|
||||
|
||||
Returns:
|
||||
List of id_names from KGRelationshipType where both source and target entity types
|
||||
are in the provided entities list
|
||||
"""
|
||||
entity_types = list({get_entity_type(entity) for entity in entities})
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
|
||||
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
|
||||
"""Get all relationship ID names where the given entity is either the source or target node.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_id: ID of the entity to find relationships for
|
||||
|
||||
Returns:
|
||||
List of relationship ID names where the entity is either source or target
|
||||
"""
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationship.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationship.source_node == entity_id,
|
||||
KGRelationship.target_node == entity_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_relationship_types_of_entity_types(
    db_session: Session, entity_types_id: str
) -> List[str]:
    """Get all relationship type ID names where the given entity type is either the source or target type.

    Args:
        db_session: SQLAlchemy session
        entity_types_id: ID of the entity type to find relationship types for

    Returns:
        List of relationship type ID names where the entity type is either source or target
    """
|
||||
|
||||
if entity_types_id.endswith(":*"):
|
||||
entity_types_id = entity_types_id[:-2]
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationshipType.source_entity_type_id_name == entity_types_id,
|
||||
KGRelationshipType.target_entity_type_id_name == entity_types_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def delete_document_references_from_kg(db_session: Session, document_id: str) -> None:
|
||||
# Delete relationships from normalized stage
|
||||
db_session.query(KGRelationship).filter(
|
||||
KGRelationship.source_document == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete relationships from extraction staging
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.source_document == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete entities from normalized stage
|
||||
db_session.query(KGEntity).filter(KGEntity.document_id == document_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
# Delete entities from extraction staging
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.document_id == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def delete_from_kg_relationships_extraction_staging__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete relationships from the extraction staging table."""
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.source_document.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def delete_from_kg_relationships__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete relationships from the normalized table."""
|
||||
db_session.query(KGRelationship).filter(
|
||||
KGRelationship.source_document.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
@@ -91,6 +91,25 @@ schema {{ schema_name }} {
            indexing: attribute
        }

        # Separate weightedset fields for knowledge graph data
        field kg_entities type weightedset<string> {
            rank: filter
            indexing: summary | attribute
            attribute: fast-search
        }

        field kg_relationships type weightedset<string> {
            indexing: summary | attribute
            rank: filter
            attribute: fast-search
        }

        field kg_terms type weightedset<string> {
            indexing: summary | attribute
            rank: filter
            attribute: fast-search
        }

        # Needs to have a separate Attribute list for efficient filtering
        field metadata_list type array<string> {
            indexing: summary | attribute

@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
    # build the list of fields to retrieve
    field_set_list = (
-       None
-       if not field_names
-       else [f"{index_name}:{field_name}" for field_name in field_names]
+       [f"{field_name}" for field_name in field_names] if field_names else []
    )
-   acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
+   acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
    if (
        field_set_list
        and filters.access_control_list
        and acl_fieldset_entry not in field_set_list
    ):
        field_set_list.append(acl_fieldset_entry)
-   field_set = ",".join(field_set_list) if field_set_list else None
+   if field_set_list:
+       field_set = f"{index_name}:" + ",".join(field_set_list)
+   else:
+       field_set = None
|
||||
|
||||
# build filters
|
||||
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
|
||||
|
||||
@@ -18,6 +18,7 @@ from uuid import UUID
|
||||
import httpx # type: ignore
|
||||
import jinja2
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
@@ -30,6 +31,7 @@ from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
@@ -86,6 +88,17 @@ httpx_logger = logging.getLogger("httpx")
|
||||
httpx_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def update_kg_type_dict(
|
||||
dict_to_update: dict[str, dict], kg_type: str, value_set: set[str]
|
||||
) -> dict[str, dict]:
|
||||
if "fields" not in dict_to_update:
|
||||
dict_to_update["fields"] = {}
|
||||
dict_to_update["fields"][kg_type] = {
|
||||
"assign": {kg_type_object: 1 for kg_type_object in value_set}
|
||||
}
|
||||
return dict_to_update
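# Illustrative sketch, not part of this change (the sample entity IDs are assumptions): the
# helper above builds the body of a Vespa partial-update request, wrapping each value set
# into the "assign" form used for weightedset<string> fields, with a weight of 1 per entry.
example_update: dict[str, dict] = {"fields": {}}
example_update = update_kg_type_dict(
    example_update, "kg_entities", {"ACCOUNT:acme", "EMPLOYEE:jane_doe"}
)
# example_update == {
#     "fields": {"kg_entities": {"assign": {"ACCOUNT:acme": 1, "EMPLOYEE:jane_doe": 1}}}
# }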
|
||||
|
||||
|
||||
@dataclass
|
||||
class _VespaUpdateRequest:
|
||||
document_id: str
|
||||
@@ -93,6 +106,39 @@ class _VespaUpdateRequest:
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGVespaChunkUpdateRequest(BaseModel):
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
url: str
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGUChunkUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a single chunk of a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: set[str] | None = None
|
||||
relationships: set[str] | None = None
|
||||
terms: set[str] | None = None
|
||||
converted_attributes: set[str] | None = None
|
||||
attributes: dict[str, str | list[str]] | None = None
|
||||
|
||||
|
||||
class KGUDocumentUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
entities: set[str]
|
||||
relationships: set[str]
|
||||
terms: set[str]
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -501,6 +547,51 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
@classmethod
|
||||
def _apply_kg_chunk_updates_batched(
|
||||
cls,
|
||||
updates: list[KGVespaChunkUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
|
||||
def _kg_update_chunk(
|
||||
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
|
||||
) -> httpx.Response:
|
||||
# logger.debug(
|
||||
# f"Updating KG with request to {update.url} with body {update.update_request}"
|
||||
# )
|
||||
return http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
httpx_client as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
executor.submit(
|
||||
_kg_update_chunk,
|
||||
update,
|
||||
http_client,
|
||||
): update.document_id
|
||||
for update in update_batch
|
||||
}
|
||||
for future in concurrent.futures.as_completed(future_to_document_id):
|
||||
res = future.result()
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
@@ -584,6 +675,63 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def kg_chunk_updates(
|
||||
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
|
||||
) -> None:
|
||||
|
||||
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
|
||||
update_start = time.monotonic()
|
||||
|
||||
# Build the KGVespaChunkUpdateRequest objects
|
||||
|
||||
for kg_update_request in kg_update_requests:
|
||||
kg_update_dict: dict[str, dict] = {"fields": {}}
|
||||
|
||||
if kg_update_request.relationships is not None:
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_relationships", kg_update_request.relationships
|
||||
)
|
||||
|
||||
if kg_update_request.entities is not None:
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_entities", kg_update_request.entities
|
||||
)
|
||||
|
||||
if kg_update_request.terms is not None:
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_terms", kg_update_request.terms
|
||||
)
|
||||
|
||||
if not kg_update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
continue
|
||||
|
||||
doc_chunk_id = get_uuid_from_chunk_info(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
tenant_id=tenant_id,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
processed_updates_requests.append(
|
||||
KGVespaChunkUpdateRequest(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=kg_update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_kg_chunk_updates_batched(
|
||||
processed_updates_requests, httpx_client
|
||||
)
|
||||
logger.debug(
|
||||
"Updated %d vespa documents in %.2f seconds",
|
||||
len(processed_updates_requests),
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
|
||||
78
backend/onyx/document_index/vespa/kg_interactions.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from retry import retry
|
||||
|
||||
from onyx.db.document import get_document_kg_entities_and_relationships
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import generalize_relationships
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def update_kg_chunks_vespa_info(
|
||||
kg_update_requests: list[KGUChunkUpdateRequest],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
""" """
|
||||
# Use the existing visit API infrastructure
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=False,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=None,
|
||||
)
|
||||
|
||||
vespa_index.kg_chunk_updates(
|
||||
kg_update_requests=kg_update_requests, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
|
||||
def get_kg_vespa_info_update_requests_for_document(
|
||||
document_id: str, index_name: str, tenant_id: str
|
||||
) -> list[KGUChunkUpdateRequest]:
|
||||
"""Get the kg_info update requests for a document."""
|
||||
# get all entities and relationships tied to the document
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entities, relationships = get_document_kg_entities_and_relationships(
|
||||
db_session, document_id
|
||||
)
|
||||
|
||||
# create the kg vespa info
|
||||
entity_id_names = [entity.id_name for entity in entities]
|
||||
relationship_id_names = [relationship.id_name for relationship in relationships]
|
||||
|
||||
kg_entities = generalize_entities(entity_id_names) | set(entity_id_names)
|
||||
kg_relationships = generalize_relationships(relationship_id_names) | set(
|
||||
relationship_id_names
|
||||
)
|
||||
|
||||
# get chunks in the document
|
||||
chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None, tenant_id=tenant_id),
|
||||
field_names=["chunk_id"],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
# get vespa update requests
|
||||
return [
|
||||
KGUChunkUpdateRequest(
|
||||
document_id=document_id,
|
||||
chunk_id=chunk["fields"]["chunk_id"],
|
||||
core_entity="unused",
|
||||
entities=kg_entities,
|
||||
relationships=kg_relationships or None,
|
||||
)
|
||||
for chunk in chunks
|
||||
]
|
||||
@@ -54,6 +54,47 @@ def build_vespa_filters(
|
||||
|
||||
return result
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
kg_relationships: list[str] | None,
|
||||
kg_terms: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_entities and not kg_relationships and not kg_terms:
|
||||
return ""
|
||||
|
||||
filter_parts = []
|
||||
|
||||
# Process each filter type using the same pattern
|
||||
for filter_type, values in [
|
||||
("kg_entities", kg_entities),
|
||||
("kg_relationships", kg_relationships),
|
||||
("kg_terms", kg_terms),
|
||||
]:
|
||||
if values:
|
||||
filter_parts.append(
|
||||
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
|
||||
)
|
||||
|
||||
return f"({' and '.join(filter_parts)}) and "
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_sources:
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
untimed_doc_cutoff: timedelta = timedelta(days=92),
|
||||
@@ -71,7 +112,9 @@ def build_vespa_filters(
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
    # If running in multi-tenant mode
-   if filters.tenant_id and MULTI_TENANT:
+   if MULTI_TENANT:
+       if not filters.tenant_id:
+           raise ValueError("Tenant ID is required in multi-tenant mode")
        filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
|
||||
|
||||
# ACL filters
|
||||
@@ -106,6 +149,19 @@ def build_vespa_filters(
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# Knowledge Graph Filters
|
||||
filter_str += _build_kg_filter(
|
||||
kg_entities=filters.kg_entities,
|
||||
kg_relationships=filters.kg_relationships,
|
||||
kg_terms=filters.kg_terms,
|
||||
)
|
||||
|
||||
filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
|
||||
filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
filters.kg_chunk_id_zero_only or False
|
||||
)
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
279
backend/onyx/kg/clustering/clustering.py
Normal file
@@ -0,0 +1,279 @@
|
||||
from typing import cast
|
||||
|
||||
from rapidfuzz.fuzz import ratio
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.configs.kg_configs import KG_CLUSTERING_RETRIEVE_THRESHOLD
|
||||
from onyx.configs.kg_configs import KG_CLUSTERING_THRESHOLD
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import KGEntity
|
||||
from onyx.db.entities import KGEntityExtractionStaging
|
||||
from onyx.db.entities import merge_entities
|
||||
from onyx.db.entities import transfer_entity
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.db.relationships import get_parent_child_relationships_and_types
|
||||
from onyx.db.relationships import transfer_relationship
|
||||
from onyx.db.relationships import transfer_relationship_type
|
||||
from onyx.document_index.vespa.kg_interactions import (
|
||||
get_kg_vespa_info_update_requests_for_document,
|
||||
)
|
||||
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
|
||||
from onyx.kg.models import KGGroundingType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _cluster_one_grounded_entity(
|
||||
entity: KGEntityExtractionStaging,
|
||||
) -> tuple[KGEntity, bool]:
|
||||
"""
|
||||
Cluster a single grounded entity.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# get entity name and filtering conditions
|
||||
if entity.document_id is not None:
|
||||
entity_name = cast(
|
||||
str,
|
||||
db_session.query(Document.semantic_id)
|
||||
.filter(Document.id == entity.document_id)
|
||||
.scalar(),
|
||||
).lower()
|
||||
filtering = [KGEntity.document_id.is_(None)]
|
||||
else:
|
||||
entity_name = entity.name.lower()
|
||||
filtering = []
|
||||
|
||||
# skip those with numbers so we don't cluster version1 and version2, etc.
|
||||
similar_entities: list[KGEntity] = []
|
||||
if not any(char.isdigit() for char in entity_name):
|
||||
# find similar entities, uses GIN index, very efficient
|
||||
db_session.execute(
|
||||
text(
|
||||
"SET pg_trgm.similarity_threshold = "
|
||||
+ str(KG_CLUSTERING_RETRIEVE_THRESHOLD)
|
||||
)
|
||||
)
|
||||
similar_entities = (
|
||||
db_session.query(KGEntity)
|
||||
.filter(
|
||||
# find entities of the same type with a similar name
|
||||
*filtering,
|
||||
KGEntity.entity_type_id_name == entity.entity_type_id_name,
|
||||
KGEntity.name.op("%")(entity_name),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# find best match
|
||||
best_score = -1.0
|
||||
best_entity = None
|
||||
for similar in similar_entities:
|
||||
# skip those with numbers so we don't cluster version1 and version2, etc.
|
||||
if any(char.isdigit() for char in similar.name):
|
||||
continue
|
||||
score = ratio(similar.name, entity_name)
|
||||
if score >= KG_CLUSTERING_THRESHOLD * 100 and score > best_score:
|
||||
best_score = score
|
||||
best_entity = similar
|
||||
|
||||
# if there is a match, update the entity, otherwise create a new one
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if best_entity:
|
||||
logger.debug(f"Merged {entity.name} with {best_entity.name}")
|
||||
update_vespa = (
|
||||
best_entity.document_id is None and entity.document_id is not None
|
||||
)
|
||||
transferred_entity = merge_entities(
|
||||
db_session=db_session, parent=best_entity, child=entity
|
||||
)
|
||||
else:
|
||||
update_vespa = entity.document_id is not None
|
||||
transferred_entity = transfer_entity(db_session=db_session, entity=entity)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return transferred_entity, update_vespa
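# Illustrative sketch, not part of this change (the names and threshold value are assumptions):
# rapidfuzz's ratio (imported at the top of this module) returns a 0-100 similarity, so a
# candidate is merged only when ratio(candidate, entity_name) >= KG_CLUSTERING_THRESHOLD * 100.
example_score = ratio("acme corp", "acme corporation")  # roughly 72 on the 0-100 scale
# with a threshold of e.g. 0.8 this pair would NOT be merged, while "acme corp." vs
# "acme corp" scores ~95 and would be.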
|
||||
|
||||
|
||||
def _transfer_batch_relationship(
|
||||
relationships: list[KGRelationshipExtractionStaging],
|
||||
entity_translations: dict[str, str],
|
||||
) -> set[str]:
|
||||
updated_documents: set[str] = set()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_id_names: set[str] = set()
|
||||
for relationship in relationships:
|
||||
transferred_relationship = transfer_relationship(
|
||||
db_session=db_session,
|
||||
relationship=relationship,
|
||||
entity_translations=entity_translations,
|
||||
)
|
||||
entity_id_names.add(transferred_relationship.source_node)
|
||||
entity_id_names.add(transferred_relationship.target_node)
|
||||
|
||||
updated_documents.update(
|
||||
(
|
||||
res[0]
|
||||
for res in db_session.query(KGEntity.document_id)
|
||||
.filter(KGEntity.id_name.in_(entity_id_names))
|
||||
.all()
|
||||
if res[0] is not None
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return updated_documents
|
||||
|
||||
|
||||
def kg_clustering(
|
||||
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 16
|
||||
) -> None:
|
||||
"""
|
||||
Here we will cluster the extractions based on their cluster frameworks.
|
||||
Initially, this will only focus on grounded entities with pre-determined
|
||||
relationships, so 'clustering' is actually not yet required.
|
||||
However, we may need to reconcile entities coming from different sources.
|
||||
|
||||
The primary purpose of this function is to populate the actual KG tables
|
||||
from the extraction staging tables.
|
||||
|
||||
This will change with deep extraction, where grounded-sourceless entities
|
||||
can be extracted and then need to be clustered.
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg clustering for tenant {tenant_id}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
# Retrieve staging data
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
untransferred_relationship_types = (
|
||||
db_session.query(KGRelationshipTypeExtractionStaging)
|
||||
.filter(KGRelationshipTypeExtractionStaging.transferred.is_(False))
|
||||
.all()
|
||||
)
|
||||
untransferred_relationships = (
|
||||
db_session.query(KGRelationshipExtractionStaging)
|
||||
.filter(KGRelationshipExtractionStaging.transferred.is_(False))
|
||||
.all()
|
||||
)
|
||||
grounded_entities = (
|
||||
db_session.query(KGEntityExtractionStaging)
|
||||
.join(
|
||||
KGEntityType,
|
||||
KGEntityExtractionStaging.entity_type_id_name == KGEntityType.id_name,
|
||||
)
|
||||
.filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Cluster and transfer grounded entities
|
||||
untransferred_grounded_entities = [
|
||||
entity for entity in grounded_entities if entity.transferred_id_name is None
|
||||
]
|
||||
entity_translations: dict[str, str] = {
|
||||
entity.id_name: entity.transferred_id_name
|
||||
for entity in grounded_entities
|
||||
if entity.transferred_id_name is not None
|
||||
}
|
||||
vespa_update_documents: set[str] = set()
|
||||
|
||||
for entity in untransferred_grounded_entities:
|
||||
added_entity, update_vespa = _cluster_one_grounded_entity(entity)
|
||||
entity_translations[entity.id_name] = added_entity.id_name
|
||||
if update_vespa and added_entity.document_id is not None:
|
||||
vespa_update_documents.add(added_entity.document_id)
|
||||
logger.info(f"Transferred {len(untransferred_grounded_entities)} entities")
|
||||
|
||||
# Add parent-child relationships and relationship types
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
parent_child_relationships, parent_child_relationship_types = (
|
||||
get_parent_child_relationships_and_types(
|
||||
db_session, depth=kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH
|
||||
)
|
||||
)
|
||||
untransferred_relationship_types.extend(parent_child_relationship_types)
|
||||
untransferred_relationships.extend(parent_child_relationships)
|
||||
db_session.commit()
|
||||
|
||||
# Transfer the relationship types
|
||||
for relationship_type in untransferred_relationship_types:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
transfer_relationship_type(db_session, relationship_type=relationship_type)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Transferred {len(untransferred_relationship_types)} relationship types"
|
||||
)
|
||||
|
||||
# Transfer relationships in parallel
|
||||
updated_documents_batch: list[set[str]] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
_transfer_batch_relationship,
|
||||
(
|
||||
untransferred_relationships[
|
||||
batch_i : batch_i + processing_chunk_batch_size
|
||||
],
|
||||
entity_translations,
|
||||
),
|
||||
)
|
||||
for batch_i in range(
|
||||
0, len(untransferred_relationships), processing_chunk_batch_size
|
||||
)
|
||||
]
|
||||
)
|
||||
for updated_documents in updated_documents_batch:
|
||||
vespa_update_documents.update(updated_documents)
|
||||
logger.info(f"Transferred {len(untransferred_relationships)} relationships")
|
||||
|
||||
# Update vespa for documents that had their kg info updated in parallel
|
||||
for i in range(0, len(vespa_update_documents), processing_chunk_batch_size):
|
||||
batch_update_requests = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(
|
||||
get_kg_vespa_info_update_requests_for_document,
|
||||
(document_id, index_name, tenant_id),
|
||||
)
|
||||
for document_id in list(vespa_update_documents)[
|
||||
i : i + processing_chunk_batch_size
|
||||
]
|
||||
]
|
||||
)
|
||||
for update_requests in batch_update_requests:
|
||||
update_kg_chunks_vespa_info(update_requests, index_name, tenant_id)
|
||||
|
||||
# Delete the transferred objects from the staging tables
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.transferred.is_(True)
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationships: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).filter(
|
||||
KGRelationshipTypeExtractionStaging.transferred.is_(True)
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationship types: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.transferred_id_name.is_not(None)
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entities: {e}")
|
||||
403
backend/onyx/kg/clustering/normalizations.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
from nltk import ngrams # type: ignore
|
||||
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import Float
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
|
||||
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
|
||||
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_THRESHOLD
|
||||
from onyx.configs.kg_configs import KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.relationships import get_relationships_for_entity_type_pairs
|
||||
from onyx.kg.models import NormalizedEntities
|
||||
from onyx.kg.models import NormalizedRelationships
|
||||
from onyx.kg.models import NormalizedTerms
|
||||
from onyx.kg.utils.embeddings import encode_string_batch
|
||||
from onyx.kg.utils.formatting_utils import format_entity_id_for_models
|
||||
from onyx.kg.utils.formatting_utils import get_entity_type
|
||||
from onyx.kg.utils.formatting_utils import make_relationship_id
|
||||
from onyx.kg.utils.formatting_utils import split_entity_id
|
||||
from onyx.kg.utils.formatting_utils import split_relationship_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
alphanum_regex = re.compile(r"[^a-z0-9]+")
|
||||
rem_email_regex = re.compile(r"(?<=\S)@([a-z0-9-]+)\.([a-z]{2,6})$")
|
||||
|
||||
|
||||
def _clean_name(entity_name: str) -> str:
|
||||
"""
|
||||
Clean an entity string by removing a trailing email domain and any non-alphanumeric characters.
|
||||
If the name after cleaning is empty, return the original name in lowercase.
|
||||
"""
|
||||
cleaned_entity = entity_name.casefold()
|
||||
return (
|
||||
alphanum_regex.sub("", rem_email_regex.sub("", cleaned_entity))
|
||||
or cleaned_entity
|
||||
)
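# Illustrative sketch, not part of this change (the sample strings are assumptions): the two
# regexes above first strip a trailing email domain, then drop any remaining non-alphanumerics.
assert _clean_name("Jane.Doe@example.com") == "janedoe"
assert _clean_name("!!!") == "!!!"  # nothing alphanumeric left, so the lowercased original is kept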
|
||||
|
||||
|
||||
def _normalize_one_entity(
|
||||
entity: str, allowed_docs_temp_view_name: str | None = None
|
||||
) -> str | None:
|
||||
"""
|
||||
Matches a single entity to the best matching entity of the same type.
|
||||
"""
|
||||
entity_type, entity_name = split_entity_id(entity)
|
||||
if entity_name == "*":
|
||||
return entity
|
||||
|
||||
cleaned_entity = _clean_name(entity_name)
|
||||
|
||||
# step 1: find entities containing the entity_name or something similar
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# get allowed documents
|
||||
|
||||
metadata = MetaData()
|
||||
if allowed_docs_temp_view_name is None:
|
||||
raise ValueError("allowed_docs_temp_view_name is not available")
|
||||
allowed_docs_temp_view = Table(
|
||||
allowed_docs_temp_view_name,
|
||||
metadata,
|
||||
autoload_with=db_session.get_bind(),
|
||||
)
|
||||
|
||||
# generate trigrams of the queried entity Q
|
||||
query_trigrams = db_session.query(
|
||||
func.show_trgm(cleaned_entity).cast(ARRAY(String(3))).label("trigrams")
|
||||
).cte("query")
|
||||
|
||||
candidates = cast(
|
||||
list[tuple[str, str, float]],
|
||||
db_session.query(
|
||||
KGEntity.id_name,
|
||||
KGEntity.name,
|
||||
(
|
||||
# for each entity E, compute score = | Q ∩ E | / min(|Q|, |E|)
|
||||
func.cardinality(
|
||||
func.array(
|
||||
select(func.unnest(KGEntity.name_trigrams))
|
||||
.correlate(KGEntity)
|
||||
.intersect(
|
||||
select(
|
||||
func.unnest(query_trigrams.c.trigrams)
|
||||
).correlate(query_trigrams)
|
||||
)
|
||||
.scalar_subquery()
|
||||
)
|
||||
).cast(Float)
|
||||
/ func.least(
|
||||
func.cardinality(query_trigrams.c.trigrams),
|
||||
func.cardinality(KGEntity.name_trigrams),
|
||||
)
|
||||
).label("score"),
|
||||
)
|
||||
.select_from(KGEntity, query_trigrams)
|
||||
.outerjoin(
|
||||
allowed_docs_temp_view,
|
||||
KGEntity.document_id == allowed_docs_temp_view.c.allowed_doc_id,
|
||||
)
|
||||
.filter(
|
||||
KGEntity.entity_type_id_name == entity_type,
|
||||
KGEntity.name_trigrams.overlap(query_trigrams.c.trigrams),
|
||||
# Add filter for allowed docs - either document_id is NULL or it's in allowed_docs
|
||||
(
|
||||
KGEntity.document_id.is_(None)
|
||||
| allowed_docs_temp_view.c.allowed_doc_id.isnot(None)
|
||||
),
|
||||
)
|
||||
.order_by(desc("score"))
|
||||
.limit(KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT)
|
||||
.all(),
|
||||
)
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# step 2: do a weighted ngram analysis and damerau levenshtein distance to rerank
|
||||
n1, n2, n3 = (
|
||||
set(ngrams(cleaned_entity, 1)),
|
||||
set(ngrams(cleaned_entity, 2)),
|
||||
set(ngrams(cleaned_entity, 3)),
|
||||
)
|
||||
for i, (candidate_id_name, candidate_name, _) in enumerate(candidates):
|
||||
cleaned_candidate = _clean_name(candidate_name)
|
||||
h_n1, h_n2, h_n3 = (
|
||||
set(ngrams(cleaned_candidate, 1)),
|
||||
set(ngrams(cleaned_candidate, 2)),
|
||||
set(ngrams(cleaned_candidate, 3)),
|
||||
)
|
||||
|
||||
# compute ngram overlap, renormalize scores if the names are too short for larger ngrams
|
||||
grams_used = min(2, len(cleaned_entity) - 1, len(cleaned_candidate) - 1)
|
||||
W_n1, W_n2, W_n3 = KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
|
||||
ngram_score = (
|
||||
# compute | Q ∩ E | / min(|Q|, |E|) for unigrams and bigrams (trigrams already computed)
|
||||
W_n1 * len(n1 & h_n1) / max(1, min(len(n1), len(h_n1)))
|
||||
+ W_n2 * len(n2 & h_n2) / max(1, min(len(n2), len(h_n2)))
|
||||
+ W_n3 * len(n3 & h_n3) / max(1, min(len(n3), len(h_n3)))
|
||||
) / (W_n1, W_n1 + W_n2, 1.0)[grams_used]
|
||||
|
||||
# compute damerau levenshtein distance to fuzzy match against typos
|
||||
W_leven = KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
|
||||
leven_score = normalized_similarity(cleaned_entity, cleaned_candidate)
|
||||
|
||||
# combine scores
|
||||
score = (1.0 - W_leven) * ngram_score + W_leven * leven_score
|
||||
candidates[i] = (candidate_id_name, candidate_name, score)
|
||||
candidates = list(
|
||||
sorted(
|
||||
filter(lambda x: x[2] > KG_NORMALIZATION_RERANK_THRESHOLD, candidates),
|
||||
key=lambda x: x[2],
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0][0]
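# Worked sketch of the rerank score above, not part of this change (the weight values are
# assumptions; the real ones come from KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS and
# KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT). Each ngram term is |Q ∩ E| / min(|Q|, |E|),
# renormalized when the names are too short to form bigrams or trigrams.
W_n1, W_n2, W_n3, W_leven = 0.2, 0.3, 0.5, 0.4
ngram_score = (W_n1 * 1.0 + W_n2 * 0.8 + W_n3 * 0.6) / 1.0  # = 0.74, all three ngram sizes usable
leven_score = 0.75  # normalized Damerau-Levenshtein similarity in [0, 1]
score = (1.0 - W_leven) * ngram_score + W_leven * leven_score  # = 0.6 * 0.74 + 0.4 * 0.75 = 0.744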
|
||||
|
||||
|
||||
def _get_existing_normalized_relationships(
|
||||
raw_relationships: list[str],
|
||||
) -> dict[str, dict[str, list[str]]]:
|
||||
"""
|
||||
Get existing normalized relationships from the database.
|
||||
"""
|
||||
|
||||
relationship_type_map: dict[str, dict[str, list[str]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
relationship_pairs = list(
|
||||
{
|
||||
(
|
||||
get_entity_type(split_relationship_id(relationship)[0]),
|
||||
get_entity_type(split_relationship_id(relationship)[2]),
|
||||
)
|
||||
for relationship in raw_relationships
|
||||
}
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationships = get_relationships_for_entity_type_pairs(
|
||||
db_session, relationship_pairs
|
||||
)
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_type_map[relationship.source_entity_type_id_name][
|
||||
relationship.target_entity_type_id_name
|
||||
].append(relationship.id_name)
|
||||
|
||||
return relationship_type_map
|
||||
|
||||
|
||||
def normalize_entities(
|
||||
raw_entities_no_attributes: list[str],
|
||||
allowed_docs_temp_view_name: str | None = None,
|
||||
) -> NormalizedEntities:
|
||||
"""
|
||||
Match each entity against a list of normalized entities using fuzzy matching.
|
||||
Returns the best matching normalized entity for each input entity.
|
||||
|
||||
Args:
|
||||
raw_entities_no_attributes: list of entity strings to normalize, w/o attributes
|
||||
|
||||
Returns:
|
||||
NormalizedEntities with the matched entities and the entity normalization map
|
||||
"""
|
||||
normalized_results: list[str] = []
|
||||
normalized_map: dict[str, str] = {}
|
||||
|
||||
mapping: list[str | None] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(_normalize_one_entity, (entity, allowed_docs_temp_view_name))
|
||||
for entity in raw_entities_no_attributes
|
||||
]
|
||||
)
|
||||
for entity, normalized_entity in zip(raw_entities_no_attributes, mapping):
|
||||
if normalized_entity is not None:
|
||||
normalized_results.append(normalized_entity)
|
||||
normalized_map[format_entity_id_for_models(entity)] = normalized_entity
|
||||
else:
|
||||
normalized_map[format_entity_id_for_models(entity)] = entity
|
||||
|
||||
return NormalizedEntities(
|
||||
entities=normalized_results, entity_normalization_map=normalized_map
|
||||
)
|
||||
|
||||
|
||||
def normalize_entities_w_attributes_from_map(
|
||||
raw_entities_w_attributes: list[str],
|
||||
entity_normalization_map: dict[str, str],
|
||||
) -> list[str]:
|
||||
"""
|
||||
Normalize entities with attributes using the entity normalization map.
|
||||
"""
|
||||
|
||||
normalized_entities_w_attributes: list[str] = []
|
||||
|
||||
for raw_entities_w_attribute in raw_entities_w_attributes:
|
||||
assert (
|
||||
len(raw_entities_w_attribute.split("--")) == 2
|
||||
), f"Invalid entity with attributes: {raw_entities_w_attribute}"
|
||||
raw_entity, attributes = raw_entities_w_attribute.split("--")
|
||||
formatted_raw_entity = format_entity_id_for_models(raw_entity)
|
||||
normalized_entity = entity_normalization_map.get(formatted_raw_entity)
|
||||
if normalized_entity is None:
|
||||
logger.warning(f"No normalized entity found for {raw_entity}")
|
||||
continue
|
||||
else:
|
||||
normalized_entities_w_attributes.append(
|
||||
f"{normalized_entity}--{raw_entities_w_attribute.split('--')[1].strip()}"
|
||||
)
|
||||
|
||||
return normalized_entities_w_attributes
|
||||
|
||||
|
||||
def normalize_relationships(
|
||||
raw_relationships: list[str], entity_normalization_map: dict[str, str]
|
||||
) -> NormalizedRelationships:
|
||||
"""
|
||||
Normalize relationships using entity mappings and relationship string matching.
|
||||
|
||||
Args:
|
||||
raw_relationships: list of relationships in format "source__relation__target"
|
||||
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
|
||||
|
||||
Returns:
|
||||
NormalizedRelationships containing normalized relationships and mapping
|
||||
"""
|
||||
# Placeholder for normalized relationship structure
|
||||
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
|
||||
|
||||
normalized_rels: list[str] = []
|
||||
normalization_map: dict[str, str] = {}
|
||||
|
||||
for raw_rel in raw_relationships:
|
||||
# 1. Split and normalize entities
|
||||
try:
|
||||
source, rel_string, target = split_relationship_id(raw_rel)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid relationship format: {raw_rel}")
|
||||
|
||||
# Check if entities are in normalization map and not None
|
||||
norm_source = entity_normalization_map.get(source)
|
||||
norm_target = entity_normalization_map.get(target)
|
||||
|
||||
if norm_source is None or norm_target is None:
|
||||
logger.warning(f"No normalized entities found for {raw_rel}")
|
||||
continue
|
||||
|
||||
# 2. Find candidate normalized relationships
|
||||
candidate_rels = []
|
||||
norm_source_type = get_entity_type(norm_source)
|
||||
norm_target_type = get_entity_type(norm_target)
|
||||
if (
|
||||
norm_source_type in nor_relationships
|
||||
and norm_target_type in nor_relationships[norm_source_type]
|
||||
):
|
||||
candidate_rels = [
|
||||
split_relationship_id(rel)[1]
|
||||
for rel in nor_relationships[norm_source_type][norm_target_type]
|
||||
]
|
||||
|
||||
if not candidate_rels:
|
||||
logger.warning(f"No candidate relationships found for {raw_rel}")
|
||||
continue
|
||||
|
||||
# 3. Encode and find best match
|
||||
strings_to_encode = [rel_string] + candidate_rels
|
||||
vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# Get raw relation vector and candidate vectors
|
||||
raw_vector = vectors[0]
|
||||
candidate_vectors = vectors[1:]
|
||||
|
||||
# Calculate dot products
|
||||
dot_products = np.dot(candidate_vectors, raw_vector)
|
||||
best_match_idx = np.argmax(dot_products)
|
||||
|
||||
# Create normalized relationship
|
||||
norm_rel = make_relationship_id(
|
||||
norm_source, candidate_rels[best_match_idx], norm_target
|
||||
)
|
||||
normalized_rels.append(norm_rel)
|
||||
normalization_map[raw_rel] = norm_rel
|
||||
|
||||
return NormalizedRelationships(
|
||||
relationships=normalized_rels, relationship_normalization_map=normalization_map
|
||||
)
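# Illustrative sketch of the embedding match in step 3, not part of this change (the strings
# are assumptions): the extracted relation wording is compared against the known relation
# types for the same source/target entity-type pair, and the highest dot product wins.
# encode_string_batch and np are the module-level imports used above.
example_vectors = encode_string_batch(["signed a contract with", "has_contract_with", "employs"])
example_scores = np.dot(example_vectors[1:], example_vectors[0])
example_best = ["has_contract_with", "employs"][int(np.argmax(example_scores))]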
|
||||
|
||||
|
||||
def normalize_terms(raw_terms: list[str]) -> NormalizedTerms:
|
||||
"""
|
||||
Normalize terms using semantic similarity matching.
|
||||
|
||||
Args:
|
||||
raw_terms: list of terms to normalize
|
||||
|
||||
Returns:
|
||||
NormalizedTerms containing normalized terms and mapping
|
||||
"""
|
||||
# # Placeholder for normalized terms - this would typically come from a predefined list
|
||||
# normalized_term_list = [
|
||||
# "algorithm",
|
||||
# "database",
|
||||
# "software",
|
||||
# "programming",
|
||||
# # ... other normalized terms ...
|
||||
# ]
|
||||
|
||||
# normalized_terms: list[str] = []
|
||||
# normalization_map: dict[str, str | None] = {}
|
||||
|
||||
# if not raw_terms:
|
||||
# return NormalizedTerms(terms=[], term_normalization_map={})
|
||||
|
||||
# # Encode all terms at once for efficiency
|
||||
# strings_to_encode = raw_terms + normalized_term_list
|
||||
# vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# # Split vectors into query terms and candidate terms
|
||||
# query_vectors = vectors[:len(raw_terms)]
|
||||
# candidate_vectors = vectors[len(raw_terms):]
|
||||
|
||||
# # Calculate similarity for each term
|
||||
# for i, term in enumerate(raw_terms):
|
||||
# # Calculate dot products with all candidates
|
||||
# similarities = np.dot(candidate_vectors, query_vectors[i])
|
||||
# best_match_idx = np.argmax(similarities)
|
||||
# best_match_score = similarities[best_match_idx]
|
||||
|
||||
# # Use a threshold to determine if the match is good enough
|
||||
# if best_match_score > 0.7: # Adjust threshold as needed
|
||||
# normalized_term = normalized_term_list[best_match_idx]
|
||||
# normalized_terms.append(normalized_term)
|
||||
# normalization_map[term] = normalized_term
|
||||
# else:
|
||||
# # If no good match found, keep original
|
||||
# normalization_map[term] = None
|
||||
|
||||
# return NormalizedTerms(
|
||||
# terms=normalized_terms,
|
||||
# term_normalization_map=normalization_map
|
||||
# )
|
||||
|
||||
return NormalizedTerms(
|
||||
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
|
||||
)
|
||||
53
backend/onyx/kg/configuration.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entity_type import populate_default_employee_account_information
|
||||
from onyx.db.entity_type import (
|
||||
populate_default_primary_grounded_entity_type_information,
|
||||
)
|
||||
from onyx.db.kg_config import get_kg_enablement
|
||||
from onyx.db.kg_config import KGConfigSettings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def populate_default_grounded_entity_types() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if not get_kg_enablement(db_session):
|
||||
logger.error(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
raise ValueError(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
|
||||
populate_default_primary_grounded_entity_type_information(db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def populate_default_account_employee_definitions() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if not get_kg_enablement(db_session):
|
||||
logger.error(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
raise ValueError(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
|
||||
populate_default_employee_account_information(db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_kg_settings(kg_config_settings: KGConfigSettings) -> None:
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
1401
backend/onyx/kg/extractions/extraction_processing.py
Normal file
File diff suppressed because it is too large
100
backend/onyx/kg/kg_default_entity_definitions.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.kg.models import KGDefaultEntityDefinition
|
||||
from onyx.kg.models import KGGroundingType
|
||||
|
||||
|
||||
class KGDefaultPrimaryGroundedEntityDefinitions(BaseModel):
|
||||
|
||||
LINEAR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A formal ticket about a product issue or improvement request.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="linear",
|
||||
)
|
||||
|
||||
GITHUB_PR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="Our (---vendor_name---) Engineering PRs describing what was actually implemented.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="github",
|
||||
)
|
||||
|
||||
FIREFLIES: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A phone call transcript between us (---vendor_name---) \
|
||||
and another account or individuals, or an internal meeting.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="fireflies",
|
||||
)
|
||||
|
||||
GONG: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A phone call transcript between us (---vendor_name---) \
|
||||
and another account or individuals, or an internal meeting.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="gong",
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A Google Drive document.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="google_drive",
|
||||
)
|
||||
|
||||
GMAIL: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="An email.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="gmail",
|
||||
)
|
||||
|
||||
JIRA: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A formal JIRA ticket about a product issue or improvement request.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="jira",
|
||||
)
|
||||
|
||||
ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A company that was, is, or potentially could be a customer of the vendor \
|
||||
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
|
||||
attributes={
|
||||
"metadata_attributes": {},
|
||||
"entity_filter_attributes": {"object_type": "Account"},
|
||||
"classification_attributes": {},
|
||||
},
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="salesforce",
|
||||
)
|
||||
OPPORTUNITY: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A sales opportunity.",
|
||||
attributes={
|
||||
"metadata_attributes": {},
|
||||
"entity_filter_attributes": {"object_type": "Opportunity"},
|
||||
"classification_attributes": {},
|
||||
},
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="salesforce",
|
||||
)
|
||||
|
||||
|
||||
class KGDefaultAccountEmployeeDefinitions(BaseModel):
|
||||
|
||||
VENDOR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="The Vendor ---vendor_name---, 'us'",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
|
||||
ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A company that was, is, or potentially could be a customer of the vendor \
|
||||
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
|
||||
EMPLOYEE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A person who speaks on \
|
||||
behalf of 'our' company (the VENDOR ---vendor_name---), NOT of another account. Therefore, employees of other companies \
|
||||
are NOT included here. If in doubt, do NOT extract.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
223
backend/onyx/kg/models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.kg_configs import KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH
|
||||
|
||||
|
||||
class KGConfigSettings(BaseModel):
|
||||
KG_EXPOSED: bool = False
|
||||
KG_ENABLED: bool = False
|
||||
KG_VENDOR: str | None = None
|
||||
KG_VENDOR_DOMAINS: list[str] | None = None
|
||||
KG_IGNORE_EMAIL_DOMAINS: list[str] | None = None
|
||||
KG_EXTRACTION_IN_PROGRESS: bool = False
|
||||
KG_CLUSTERING_IN_PROGRESS: bool = False
|
||||
KG_COVERAGE_START: datetime = datetime(1970, 1, 1)
|
||||
KG_MAX_COVERAGE_DAYS: int = 10000
|
||||
KG_MAX_PARENT_RECURSION_DEPTH: int = KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH
|
||||
|
||||
|
||||
class KGConfigVars(str, Enum):
|
||||
KG_EXPOSED = "KG_EXPOSED"
|
||||
KG_ENABLED = "KG_ENABLED"
|
||||
KG_VENDOR = "KG_VENDOR"
|
||||
KG_VENDOR_DOMAINS = "KG_VENDOR_DOMAINS"
|
||||
KG_IGNORE_EMAIL_DOMAINS = "KG_IGNORE_EMAIL_DOMAINS"
|
||||
KG_EXTRACTION_IN_PROGRESS = "KG_EXTRACTION_IN_PROGRESS"
|
||||
KG_CLUSTERING_IN_PROGRESS = "KG_CLUSTERING_IN_PROGRESS"
|
||||
KG_COVERAGE_START = "KG_COVERAGE_START"
|
||||
KG_MAX_COVERAGE_DAYS = "KG_MAX_COVERAGE_DAYS"
|
||||
KG_MAX_PARENT_RECURSION_DEPTH = "KG_MAX_PARENT_RECURSION_DEPTH"
|
||||
|
||||
|
||||
class KGChunkFormat(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
title: str
|
||||
content: str
|
||||
primary_owners: list[str]
|
||||
secondary_owners: list[str]
|
||||
source_type: str
|
||||
metadata: dict[str, str | list[str]] | None = None
|
||||
entities: dict[str, int] = {}
|
||||
relationships: dict[str, int] = {}
|
||||
terms: dict[str, int] = {}
|
||||
deep_extraction: bool = False
|
||||
|
||||
|
||||
class KGChunkExtraction(BaseModel):
|
||||
connector_id: int
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
attributes: dict[str, str | list[str]]
|
||||
|
||||
|
||||
class KGChunkId(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
|
||||
|
||||
class KGRelationshipExtraction(BaseModel):
|
||||
relationship_str: str
|
||||
source_document_id: str
|
||||
|
||||
|
||||
class KGAggregatedExtractions(BaseModel):
|
||||
grounded_entities_document_ids: dict[str, str]
|
||||
entities: dict[str, int]
|
||||
relationships: dict[str, dict[str, int]]
|
||||
terms: dict[str, int]
|
||||
attributes: dict[str, dict[str, str | list[str]]]
|
||||
|
||||
|
||||
class KGBatchExtractionStats(BaseModel):
|
||||
connector_id: int | None = None
|
||||
succeeded: list[KGChunkId]
|
||||
failed: list[KGChunkId]
|
||||
aggregated_kg_extractions: KGAggregatedExtractions
|
||||
|
||||
|
||||
class ConnectorExtractionStats(BaseModel):
|
||||
connector_id: int
|
||||
num_succeeded: int
|
||||
num_failed: int
|
||||
num_processed: int
|
||||
|
||||
|
||||
class KGPerson(BaseModel):
|
||||
name: str
|
||||
company: str
|
||||
employee: bool
|
||||
|
||||
|
||||
class NormalizedEntities(BaseModel):
|
||||
entities: list[str]
|
||||
entity_normalization_map: dict[str, str]
|
||||
|
||||
|
||||
class NormalizedRelationships(BaseModel):
|
||||
relationships: list[str]
|
||||
relationship_normalization_map: dict[str, str]
|
||||
|
||||
|
||||
class NormalizedTerms(BaseModel):
|
||||
terms: list[str]
|
||||
term_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class KGClassificationContent(BaseModel):
|
||||
document_id: str
|
||||
classification_content: str
|
||||
source_type: str
|
||||
source_metadata: dict[str, Any] | None = None
|
||||
entity_type: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGEnrichedClassificationContent(KGClassificationContent):
|
||||
classification_enabled: bool
|
||||
classification_instructions: dict[str, Any]
|
||||
deep_extraction: bool
|
||||
|
||||
|
||||
class KGClassificationDecisions(BaseModel):
|
||||
document_id: str
|
||||
classification_decision: bool
|
||||
classification_class: str | None
|
||||
source_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGClassificationInstructions(BaseModel):
|
||||
classification_enabled: bool
|
||||
classification_options: str
|
||||
classification_class_definitions: dict[str, dict[str, str | bool]]
|
||||
|
||||
|
||||
class KGExtractionInstructions(BaseModel):
|
||||
deep_extraction: bool
|
||||
active: bool
|
||||
|
||||
|
||||
class KGEntityTypeInstructions(BaseModel):
|
||||
classification_instructions: KGClassificationInstructions
|
||||
extraction_instructions: KGExtractionInstructions
|
||||
filter_instructions: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGEnhancedDocumentMetadata(BaseModel):
|
||||
entity_type: str | None
|
||||
document_attributes: dict[str, Any] | None
|
||||
deep_extraction: bool
|
||||
classification_enabled: bool
|
||||
classification_instructions: KGClassificationInstructions | None
|
||||
skip: bool
|
||||
|
||||
|
||||
class ContextPreparation(BaseModel):
|
||||
"""
|
||||
Context preparation format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_context: str
|
||||
core_entity: str
|
||||
implied_entities: list[str]
|
||||
implied_relationships: list[str]
|
||||
implied_terms: list[str]
|
||||
|
||||
|
||||
class KGDocumentClassificationPrompt(BaseModel):
|
||||
"""
|
||||
Document classification prompt format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_prompt: str | None
|
||||
|
||||
|
||||
class KGConnectorData(BaseModel):
|
||||
id: int
|
||||
source: str
|
||||
kg_coverage_days: int | None
|
||||
|
||||
|
||||
class KGStage(str, Enum):
|
||||
EXTRACTED = "extracted"
|
||||
NORMALIZED = "normalized"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
NOT_STARTED = "not_started"
|
||||
EXTRACTING = "extracting"
|
||||
DO_NOT_EXTRACT = "do_not_extract"
|
||||
|
||||
|
||||
class KGDocumentEntitiesRelationshipsAttributes(BaseModel):
|
||||
kg_core_document_id_name: str
|
||||
implied_entities: set[str]
|
||||
implied_relationships: set[str]
|
||||
converted_relationships_to_attributes: dict[str, list[str]]
|
||||
company_participant_emails: set[str]
|
||||
account_participant_emails: set[str]
|
||||
converted_attributes_to_relationships: set[str]
|
||||
document_attributes: dict[str, Any] | None
|
||||
|
||||
|
||||
class KGGroundingType(str, Enum):
|
||||
UNGROUNDED = "ungrounded"
|
||||
GROUNDED = "grounded"
|
||||
|
||||
|
||||
class KGDefaultEntityDefinition(BaseModel):
|
||||
description: str
|
||||
grounding: KGGroundingType
|
||||
active: bool = False
|
||||
grounded_source_name: str | None
|
||||
attributes: dict = {}
|
||||
entity_values: dict = {}
|
||||
21
backend/onyx/kg/resets/reset_extractions.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from onyx.db.document import update_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
|
||||
def reset_extraction_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationshipExtractionStaging).delete()
|
||||
db_session.query(KGEntityExtractionStaging).delete()
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_stages(db_session, KGStage.EXTRACTED, KGStage.NOT_STARTED)
|
||||
db_session.commit()
|
||||
26
backend/onyx/kg/resets/reset_index.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from onyx.db.document import reset_all_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
|
||||
|
||||
def reset_full_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationship).delete()
|
||||
db_session.query(KGRelationshipType).delete()
|
||||
db_session.query(KGEntity).delete()
|
||||
db_session.query(KGRelationshipExtractionStaging).delete()
|
||||
db_session.query(KGEntityExtractionStaging).delete()
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reset_all_document_kg_stages(db_session)
|
||||
db_session.commit()
|
||||
22
backend/onyx/kg/resets/reset_normalizations.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from onyx.db.document import update_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
|
||||
def reset_normalization_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationship).delete()
|
||||
db_session.query(KGEntity).delete()
|
||||
db_session.query(KGRelationshipType).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_stages(db_session, KGStage.NORMALIZED, KGStage.EXTRACTED)
|
||||
db_session.commit()
|
||||
92
backend/onyx/kg/resets/reset_source.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from sqlalchemy import or_
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.db.models import KGStage
|
||||
from onyx.kg.resets.reset_vespa import reset_vespa_kg_index
|
||||
|
||||
|
||||
def reset_source_kg_index(source_name: str, tenant_id: str, index_name: str) -> None:
|
||||
"""
|
||||
Resets the knowledge graph index and vespa for a source.
|
||||
"""
|
||||
# reset vespa for the source
|
||||
reset_vespa_kg_index(tenant_id, index_name, source_name)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# get all the entity types for the given source
|
||||
entity_types = [
|
||||
et.id_name
|
||||
for et in db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name == source_name)
|
||||
.all()
|
||||
]
|
||||
if not entity_types:
|
||||
raise ValueError(f"There are no entity types for the source {source_name}")
|
||||
|
||||
# delete the entity type from the knowledge graph
|
||||
for entity_type in entity_types:
|
||||
db_session.query(KGRelationship).filter(
|
||||
or_(
|
||||
KGRelationship.source_node_type == entity_type,
|
||||
KGRelationship.target_node_type == entity_type,
|
||||
)
|
||||
).delete()
|
||||
db_session.query(KGRelationshipType).filter(
|
||||
or_(
|
||||
KGRelationshipType.source_entity_type_id_name == entity_type,
|
||||
KGRelationshipType.target_entity_type_id_name == entity_type,
|
||||
)
|
||||
).delete()
|
||||
db_session.query(KGEntity).filter(
|
||||
KGEntity.entity_type_id_name == entity_type
|
||||
).delete()
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
or_(
|
||||
KGRelationshipExtractionStaging.source_node_type == entity_type,
|
||||
KGRelationshipExtractionStaging.target_node_type == entity_type,
|
||||
)
|
||||
).delete()
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.entity_type_id_name == entity_type
|
||||
).delete()
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).filter(
|
||||
or_(
|
||||
KGRelationshipTypeExtractionStaging.source_entity_type_id_name
|
||||
== entity_type,
|
||||
KGRelationshipTypeExtractionStaging.target_entity_type_id_name
|
||||
== entity_type,
|
||||
)
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# get all the documents for the given source
|
||||
kg_connectors = [
|
||||
connector.id
|
||||
for connector in db_session.query(Connector)
|
||||
.filter(Connector.source == DocumentSource(source_name))
|
||||
.all()
|
||||
]
|
||||
document_ids = [
|
||||
cc_pair.id
|
||||
for cc_pair in db_session.query(DocumentByConnectorCredentialPair)
|
||||
.filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors))
|
||||
.all()
|
||||
]
|
||||
|
||||
# reset the kg stage for the documents
|
||||
db_session.query(Document).filter(Document.id.in_(document_ids)).update(
|
||||
{"kg_stage": KGStage.NOT_STARTED}
|
||||
)
|
||||
db_session.commit()
|
||||
122
backend/onyx/kg/resets/reset_vespa.py
Normal file
@@ -0,0 +1,122 @@
from typing import Any

from retry import retry

from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import KGEntityType
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGVespaChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


@retry(tries=3, delay=1, backoff=2)
def _reset_vespa_for_doc(document_id: str, tenant_id: str, index_name: str) -> None:
    vespa_index = VespaIndex(
        index_name=index_name,
        secondary_index_name=None,
        large_chunks_enabled=False,
        secondary_large_chunks_enabled=False,
        multitenant=MULTI_TENANT,
        httpx_client=None,
    )

    reset_update_dict: dict[str, Any] = {
        "fields": {
            "kg_entities": {"assign": {}},
            "kg_relationships": {"assign": {}},
            "kg_terms": {"assign": {}},
        }
    }

    chunks = _get_chunks_via_visit_api(
        VespaChunkRequest(document_id=document_id),
        index_name,
        IndexFilters(access_control_list=None),
        ["chunk_id"],
        False,
    )

    vespa_requests: list[KGVespaChunkUpdateRequest] = []
    for chunk in chunks:
        doc_chunk_id = get_uuid_from_chunk_info(
            document_id=document_id,
            chunk_id=chunk["fields"]["chunk_id"],
            tenant_id=tenant_id,
            large_chunk_id=None,
        )
        vespa_requests.append(
            KGVespaChunkUpdateRequest(
                document_id=document_id,
                chunk_id=chunk["fields"]["chunk_id"],
                url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=vespa_index.index_name)}/{doc_chunk_id}",
                update_request=reset_update_dict,
            )
        )

    with vespa_index.httpx_client_context as httpx_client:
        vespa_index._apply_kg_chunk_updates_batched(vespa_requests, httpx_client)


def reset_vespa_kg_index(
    tenant_id: str, index_name: str, source_name: str | None = None
) -> None:
    """
    Reset the kg info in vespa for all documents of a given source name,
    or all documents from kg grounded sources if source_name is None.
    """
    logger.info(
        f"Resetting kg vespa index {index_name} for tenant {tenant_id}, "
        f"source: {source_name if source_name else 'all'}"
    )

    # Get all documents that need a vespa reset
    with get_session_with_current_tenant() as db_session:
        if source_name:
            # get all connectors of the given source name
            kg_connectors = [
                connector.id
                for connector in db_session.query(Connector)
                .filter(Connector.source == DocumentSource(source_name))
                .all()
            ]
        else:
            # get all connectors that have kg enabled
            kg_sources = [
                DocumentSource(et.grounded_source_name)
                for et in db_session.query(KGEntityType)
                .filter(
                    KGEntityType.grounded_source_name.is_not(None),
                    KGEntityType.active.is_(True),
                )
                .distinct()
                .all()
            ]
            kg_connectors = [
                connector.id
                for connector in db_session.query(Connector)
                .filter(Connector.source.in_(kg_sources))
                .all()
            ]

        # Get all the documents for the given connectors
        document_ids = [
            cc_pair.id
            for cc_pair in db_session.query(DocumentByConnectorCredentialPair)
            .filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors))
            .all()
        ]

    # Reset the kg fields
    for document_id in document_ids:
        _reset_vespa_for_doc(document_id, tenant_id, index_name)
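
For orientation, here is a minimal sketch of how the reset helpers above might be wired into a maintenance script. The CLI wrapper, argument names, and default values are hypothetical and not part of this commit; only the imported functions come from the files above.

# Hypothetical maintenance entry point; assumes the reset modules above are importable.
import argparse

from onyx.kg.resets.reset_index import reset_full_kg_index
from onyx.kg.resets.reset_source import reset_source_kg_index
from onyx.kg.resets.reset_vespa import reset_vespa_kg_index


def main() -> None:
    parser = argparse.ArgumentParser(description="Reset knowledge graph state")
    parser.add_argument("--tenant-id", required=True)
    parser.add_argument("--index-name", required=True)
    parser.add_argument("--source", default=None, help="Optional DocumentSource name")
    args = parser.parse_args()

    if args.source:
        # Clears KG rows and Vespa kg_* fields for one source only.
        reset_source_kg_index(args.source, args.tenant_id, args.index_name)
    else:
        # Clears Vespa kg_* fields for all grounded sources, then all KG tables.
        reset_vespa_kg_index(args.tenant_id, args.index_name)
        reset_full_kg_index()


if __name__ == "__main__":
    main()
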
23
backend/onyx/kg/utils/embeddings.py
Normal file
@@ -0,0 +1,23 @@
from typing import List

import numpy as np

from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT


def encode_string_batch(strings: List[str]) -> np.ndarray:
    with get_session_with_current_tenant() as db_session:
        current_search_settings = get_current_search_settings(db_session)
        model = EmbeddingModel.from_db_model(
            search_settings=current_search_settings,
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )
        # Get embeddings while session is still open
        embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
        return np.array(embedding)
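
A quick illustration of how encode_string_batch could be used, for example to compare two candidate entity names by cosine similarity. The strings are invented for the example, and this sketch assumes a running model server configured via the usual settings.

# Illustrative only: cosine similarity between two embedded strings.
import numpy as np

from onyx.kg.utils.embeddings import encode_string_batch

vectors = encode_string_batch(["Acme Corporation", "ACME Corp"])
a, b = vectors[0], vectors[1]
similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {similarity:.3f}")
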
443
backend/onyx/kg/utils/extraction_utils.py
Normal file
@@ -0,0 +1,443 @@
import re
from collections import defaultdict
from typing import Dict

from onyx.configs.constants import OnyxCallTypes
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_kg_entity_by_document
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import KGConfigSettings
from onyx.db.models import Document
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.models import (
    KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGDocumentEntitiesRelationshipsAttributes
from onyx.kg.models import KGEnhancedDocumentMetadata
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT


def _update_implied_entities_relationships(
    kg_core_document_id_name: str,
    owner_list: list[str],
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
    converted_relationships_to_attributes: dict[str, list[str]],
) -> tuple[set[str], set[str], set[str], set[str], dict[str, list[str]]]:

    for owner in owner_list or []:
        if is_email(owner):
            (
                implied_entities,
                implied_relationships,
                company_participant_emails,
                account_participant_emails,
            ) = kg_process_person(
                owner,
                kg_core_document_id_name,
                implied_entities,
                implied_relationships,
                company_participant_emails,
                account_participant_emails,
                relationship_type,
                kg_config_settings,
            )
        else:
            converted_relationships_to_attributes[relationship_type].append(owner)

    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
        converted_relationships_to_attributes,
    )


def kg_document_entities_relationships_attribute_generation(
    document: Document,
    doc_metadata: KGEnhancedDocumentMetadata,
    active_entities: list[str],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentEntitiesRelationshipsAttributes:
    """
    Generate entities, relationships, and attributes for a document.
    """

    # Get document entity type from the KGEnhancedDocumentMetadata
    document_entity_type = doc_metadata.entity_type
    assert document_entity_type is not None

    # Get additional document attributes from the KGEnhancedDocumentMetadata
    document_attributes = doc_metadata.document_attributes

    implied_entities: set[str] = set()
    implied_relationships: set[str] = (
        set()
    )  # 'Relationships' that will be captured as KG relationships
    converted_relationships_to_attributes: dict[str, list[str]] = defaultdict(
        list
    )  # 'Relationships' that will be captured as KG entity attributes

    converted_attributes_to_relationships: set[str] = (
        set()
    )  # Attributes that should be captured as entities and then relationships (Account = ...)

    company_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - participants from vendor
    account_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - external participants

    # Chunk treatment variables

    document_is_from_call = document_entity_type.lower() in [
        call_type.value.lower() for call_type in OnyxCallTypes
    ]

    # Get core entity

    document_id = document.id
    primary_owners = document.primary_owners
    secondary_owners = document.secondary_owners

    with get_session_with_current_tenant() as db_session:
        kg_core_document = get_kg_entity_by_document(db_session, document_id)

        if kg_core_document:
            kg_core_document_id_name = kg_core_document.id_name
        else:
            kg_core_document_id_name = make_entity_id(document_entity_type, document_id)

    # Get implied entities and relationships from primary/secondary owners

    if document_is_from_call:
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=(primary_owners or []) + (secondary_owners or []),
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
    else:
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=primary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="leads",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )

        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=secondary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )

    if document_attributes is not None:
        cleaned_document_attributes = document_attributes.copy()
        for attribute, value in document_attributes.items():
            if attribute.lower() in [x.lower() for x in active_entities]:
                converted_attributes_to_relationships.add(attribute)
                if isinstance(value, str):
                    implied_entity = make_entity_id(attribute, value)
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            kg_core_document_id_name,
                        )
                    )

                    implied_entity = make_entity_id(attribute, "*")
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            kg_core_document_id_name,
                        )
                    )
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            make_entity_id(document_entity_type, "*"),
                        )
                    )

                    implied_entity = make_entity_id(attribute, value)
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            make_entity_id(document_entity_type, "*"),
                        )
                    )

                    cleaned_document_attributes.pop(attribute)

                elif isinstance(value, list):
                    for item in value:
                        implied_entity = make_entity_id(attribute, item)
                        implied_entities.add(implied_entity)
                        implied_relationships.add(
                            make_relationship_id(
                                implied_entity,
                                f"is_{attribute}_of",
                                kg_core_document_id_name,
                            )
                        )
                    cleaned_document_attributes.pop(attribute)
            if attribute.lower().endswith("_id") or attribute.endswith("Id"):
                cleaned_document_attributes.pop(attribute)
    else:
        cleaned_document_attributes = None

    return KGDocumentEntitiesRelationshipsAttributes(
        kg_core_document_id_name=kg_core_document_id_name,
        implied_entities=implied_entities,
        implied_relationships=implied_relationships,
        company_participant_emails=company_participant_emails,
        account_participant_emails=account_participant_emails,
        converted_relationships_to_attributes=converted_relationships_to_attributes,
        converted_attributes_to_relationships=converted_attributes_to_relationships,
        document_attributes=cleaned_document_attributes,
    )


def _prepare_llm_document_content_call(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definition_string: str,
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Calls - prepare prompt for the LLM classification.
    """

    prompt = CALL_DOCUMENT_CLASSIFICATION_PROMPT.format(
        beginning_of_call_content=document_classification_content.classification_content,
        category_list=category_list,
        category_options=category_definition_string,
        vendor=kg_config_settings.KG_VENDOR,
    )

    return KGDocumentClassificationPrompt(
        llm_prompt=prompt,
    )


def kg_process_person(
    person: str,
    core_document_id_name: str,
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
) -> tuple[set[str], set[str], set[str], set[str]]:
    """
    Process a single owner and return updated sets with entities and relationships.

    Returns:
        tuple containing (implied_entities, implied_relationships, company_participant_emails, account_participant_emails)
    """

    with get_session_with_current_tenant() as db_session:
        kg_config_settings = get_kg_config_settings(db_session)
        if not kg_config_settings.KG_ENABLED:
            raise ValueError("KG is not enabled")

    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)

    kg_person = kg_email_processing(person, kg_config_settings)
    if any(
        domain.lower() in kg_person.company.lower()
        for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
    ):
        return (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
        )

    if kg_person.employee:
        company_participant_emails = company_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        if kg_person.name not in implied_entities:
            target_general = list(generalize_entities([core_document_id_name]))[0]
            employee_entity = make_entity_id("EMPLOYEE", kg_person.name)
            employee_general = make_entity_id("EMPLOYEE", "*")

            implied_entities.add(employee_entity)
            implied_relationships |= {
                make_relationship_id(
                    employee_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    employee_entity, relationship_type, target_general
                ),
                make_relationship_id(
                    employee_general, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    employee_general, relationship_type, target_general
                ),
            }
        if kg_person.company not in implied_entities:
            company_entity = make_entity_id("VENDOR", kg_person.company)

            implied_entities.add(company_entity)
            implied_relationships |= {
                make_relationship_id(
                    company_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    company_entity, relationship_type, target_general
                ),
            }

    else:
        account_participant_emails = account_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        if kg_person.company not in implied_entities:
            account_entity = make_entity_id("ACCOUNT", kg_person.company)
            account_general = make_entity_id("ACCOUNT", "*")
            target_general = list(generalize_entities([core_document_id_name]))[0]

            implied_entities |= {account_entity, account_general}
            implied_relationships |= {
                make_relationship_id(
                    account_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    account_general, relationship_type, core_document_id_name
                ),
                make_relationship_id(account_entity, relationship_type, target_general),
                make_relationship_id(
                    account_general, relationship_type, target_general
                ),
            }

    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
    )


def prepare_llm_content_extraction(
    chunk: KGChunkFormat,
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    kg_config_settings: KGConfigSettings,
) -> str:

    chunk_is_from_call = chunk.source_type.lower() in [
        call_type.value.lower() for call_type in OnyxCallTypes
    ]

    if chunk_is_from_call:

        llm_context = CALL_CHUNK_PREPROCESSING_PROMPT.format(
            participant_string=company_participant_emails,
            account_participant_string=account_participant_emails,
            vendor=kg_config_settings.KG_VENDOR,
            content=chunk.content,
        )
    else:
        llm_context = GENERAL_CHUNK_PREPROCESSING_PROMPT.format(
            content=chunk.content,
            vendor=kg_config_settings.KG_VENDOR,
        )

    return llm_context


def prepare_llm_document_content(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definitions: dict[str, Dict[str, str | bool]],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Prepare the content for the extraction classification.
    """

    category_definition_string = ""
    for category, category_data in category_definitions.items():
        category_definition_string += f"{category}: {category_data['description']}\n"

    if document_classification_content.source_type.lower() in [
        call_type.value.lower() for call_type in OnyxCallTypes
    ]:
        return _prepare_llm_document_content_call(
            document_classification_content,
            category_list,
            category_definition_string,
            kg_config_settings,
        )

    else:
        return KGDocumentClassificationPrompt(
            llm_prompt=None,
        )


def is_email(email: str) -> bool:
    """
    Check if a string is a valid email address.
    """
    return re.match(r"[^@]+@[^@]+\.[^@]+", email) is not None
182
backend/onyx/kg/utils/formatting_utils.py
Normal file
@@ -0,0 +1,182 @@
from collections import defaultdict

from onyx.db.kg_config import KGConfigSettings
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGPerson


def format_entity_id(entity_id_name: str) -> str:
    return make_entity_id(*split_entity_id(entity_id_name))


def make_entity_id(entity_type: str, entity_name: str) -> str:
    return f"{entity_type.upper()}::{entity_name.lower()}"


def split_entity_id(entity_id_name: str) -> list[str]:
    return entity_id_name.split("::")


def get_entity_type(entity_id_name: str) -> str:
    return entity_id_name.split("::", 1)[0].upper()


def format_entity_id_for_models(entity_id_name: str) -> str:
    entity_split = entity_id_name.split("::")
    if len(entity_split) == 2:
        entity_type, entity_name = entity_split
        separator = "::"
    elif len(entity_split) > 2:
        raise ValueError(f"Entity {entity_id_name} is not in the correct format")
    else:
        entity_name = entity_id_name
        separator = entity_type = ""

    formatted_entity_type = entity_type.strip().upper()
    formatted_entity_name = (
        entity_name.strip().replace('"', "").replace("'", "").title()
    )

    return f"{formatted_entity_type}{separator}{formatted_entity_name}"


def format_relationship_id(relationship_id_name: str) -> str:
    return make_relationship_id(*split_relationship_id(relationship_id_name))


def make_relationship_id(
    source_node: str, relationship_type: str, target_node: str
) -> str:
    return (
        f"{format_entity_id(source_node)}__"
        f"{relationship_type.lower()}__"
        f"{format_entity_id(target_node)}"
    )


def split_relationship_id(relationship_id_name: str) -> list[str]:
    return relationship_id_name.split("__")


def format_relationship_type_id(relationship_type_id_name: str) -> str:
    return make_relationship_type_id(
        *split_relationship_type_id(relationship_type_id_name)
    )


def make_relationship_type_id(
    source_node_type: str, relationship_type: str, target_node_type: str
) -> str:
    return (
        f"{source_node_type.upper()}__"
        f"{relationship_type.lower()}__"
        f"{target_node_type.upper()}"
    )


def split_relationship_type_id(relationship_type_id_name: str) -> list[str]:
    return relationship_type_id_name.split("__")


def extract_relationship_type_id(relationship_id_name: str) -> str:
    source_node, relationship_type, target_node = split_relationship_id(
        relationship_id_name
    )
    return make_relationship_type_id(
        get_entity_type(source_node), relationship_type, get_entity_type(target_node)
    )


def aggregate_kg_extractions(
    connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
) -> KGAggregatedExtractions:
    aggregated_kg_extractions = KGAggregatedExtractions(
        grounded_entities_document_ids=defaultdict(str),
        entities=defaultdict(int),
        relationships=defaultdict(lambda: defaultdict(int)),
        terms=defaultdict(int),
        attributes=defaultdict(dict),
    )
    for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
        for (
            grounded_entity,
            document_id,
        ) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
            aggregated_kg_extractions.grounded_entities_document_ids[
                grounded_entity
            ] = document_id

        for entity, count in connector_aggregated_kg_extractions.entities.items():
            if entity not in aggregated_kg_extractions.entities:
                aggregated_kg_extractions.entities[entity] = count
            else:
                aggregated_kg_extractions.entities[entity] += count
        for (
            relationship,
            relationship_data,
        ) in connector_aggregated_kg_extractions.relationships.items():
            for source_document_id, count in relationship_data.items():
                if relationship not in aggregated_kg_extractions.relationships:
                    aggregated_kg_extractions.relationships[relationship] = defaultdict(
                        int
                    )
                aggregated_kg_extractions.relationships[relationship][
                    source_document_id
                ] += count
        for term, count in connector_aggregated_kg_extractions.terms.items():
            if term not in aggregated_kg_extractions.terms:
                aggregated_kg_extractions.terms[term] = count
            else:
                aggregated_kg_extractions.terms[term] += count

    return aggregated_kg_extractions


def kg_email_processing(email: str, kg_config_settings: KGConfigSettings) -> KGPerson:
    """
    Process the email.
    """
    name, company_domain = email.split("@")
    assert isinstance(company_domain, str)
    assert isinstance(kg_config_settings.KG_VENDOR_DOMAINS, list)
    assert isinstance(kg_config_settings.KG_VENDOR, str)

    employee = any(
        domain in company_domain for domain in kg_config_settings.KG_VENDOR_DOMAINS
    )
    if employee:
        company = kg_config_settings.KG_VENDOR
    else:
        company = company_domain.capitalize()

    return KGPerson(name=name, company=company, employee=employee)


def generalize_entities(entities: list[str]) -> set[str]:
    """
    Generalize entities to their superclass.
    """
    return {make_entity_id(get_entity_type(entity), "*") for entity in entities}


def generalize_relationships(relationships: list[str]) -> set[str]:
    """
    Generalize relationships to their superclass.
    """
    generalized_relationships: set[str] = set()
    for relationship in relationships:
        assert (
            len(relationship.split("__")) == 3
        ), "Relationship is not in the correct format"
        source_entity, relationship_type, target_entity = split_relationship_id(
            relationship
        )
        source_general = make_entity_id(get_entity_type(source_entity), "*")
        target_general = make_entity_id(get_entity_type(target_entity), "*")
        generalized_relationships |= {
            make_relationship_id(source_general, relationship_type, target_entity),
            make_relationship_id(source_entity, relationship_type, target_general),
            make_relationship_id(source_general, relationship_type, target_general),
        }

    return generalized_relationships
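
To make the ID conventions above concrete, here is what the helpers produce for a hypothetical employee and Jira-style document entity. The values are invented for illustration; the functions and their behavior are taken from formatting_utils.py as shown above.

from onyx.kg.utils.formatting_utils import (
    extract_relationship_type_id,
    make_entity_id,
    make_relationship_id,
)

jira = make_entity_id("JIRA", "PROJ-123")          # "JIRA::proj-123"
employee = make_entity_id("EMPLOYEE", "Jane Doe")  # "EMPLOYEE::jane doe"

rel = make_relationship_id(employee, "participates_in", jira)
# "EMPLOYEE::jane doe__participates_in__JIRA::proj-123"

rel_type = extract_relationship_type_id(rel)
# "EMPLOYEE__participates_in__JIRA"
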
253
backend/onyx/kg/vespa/vespa_interactions.py
Normal file
@@ -0,0 +1,253 @@
import json
from collections.abc import Generator

from onyx.configs.constants import OnyxCallTypes
from onyx.db.kg_config import KGConfigSettings
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.utils.logger import setup_logger

logger = setup_logger()


def get_document_classification_content_for_kg_processing(
    document_ids: list[str],
    source: str,
    index_name: str,
    kg_config_settings: KGConfigSettings,
    batch_size: int = 8,
    num_classification_chunks: int = 3,
    entity_type: str | None = None,
) -> Generator[list[KGClassificationContent], None, None]:
    """
    Generates the content used for initial classification of a document from
    the first num_classification_chunks chunks.
    """

    classification_content_list: list[KGClassificationContent] = []

    for i in range(0, len(document_ids), batch_size):
        batch_document_ids = document_ids[i : i + batch_size]
        for document_id in batch_document_ids:
            first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
                chunk_request=VespaChunkRequest(
                    document_id=document_id,
                    max_chunk_ind=num_classification_chunks - 1,
                    min_chunk_ind=0,
                ),
                index_name=index_name,
                filters=IndexFilters(access_control_list=None),
                field_names=[
                    "document_id",
                    "chunk_id",
                    "title",
                    "content",
                    "metadata",
                    "source_type",
                    "primary_owners",
                    "secondary_owners",
                ],
                get_large_chunks=False,
            )

            if len(first_num_classification_chunks) == 0:
                continue

            first_num_classification_chunks = sorted(
                first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
            )[:num_classification_chunks]

            classification_content = _get_classification_content_from_chunks(
                first_num_classification_chunks,
                kg_config_settings,
            )

            metadata = first_num_classification_chunks[0]["fields"]["metadata"]
            if isinstance(metadata, str):
                metadata = json.loads(metadata)
            assert isinstance(metadata, dict) or metadata is None

            classification_content_list.append(
                KGClassificationContent(
                    document_id=document_id,
                    classification_content=classification_content,
                    source_type=first_num_classification_chunks[0]["fields"][
                        "source_type"
                    ],
                    source_metadata=metadata,
                    entity_type=entity_type,
                )
            )

        # Yield the batch of classification content
        if classification_content_list:
            yield classification_content_list
            classification_content_list = []

    # Yield any remaining items
    if classification_content_list:
        yield classification_content_list


def get_document_chunks_for_kg_processing(
    document_id: str,
    deep_extraction: bool,
    index_name: str,
    tenant_id: str,
    batch_size: int = 8,
) -> Generator[list[KGChunkFormat], None, None]:
    """
    Retrieves chunks from Vespa for the given document and converts them to KGChunks.

    Args:
        document_id (str): ID of the document to fetch chunks for
        deep_extraction (bool): Whether to perform deep extraction
        index_name (str): Name of the Vespa index
        tenant_id (str): ID of the tenant
        batch_size (int): Number of chunks to fetch per batch

    Yields:
        list[KGChunk]: Batches of chunks ready for KG processing
    """

    current_batch: list[KGChunkFormat] = []

    # get all chunks for the document
    chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=IndexFilters(access_control_list=None, tenant_id=tenant_id),
        field_names=[
            "document_id",
            "chunk_id",
            "title",
            "content",
            "metadata",
            "primary_owners",
            "secondary_owners",
            "source_type",
            "kg_entities",
            "kg_relationships",
            "kg_terms",
        ],
        get_large_chunks=False,
    )

    # Convert Vespa chunks to KGChunks
    for i, chunk in enumerate(chunks):
        fields = chunk["fields"]
        if isinstance(fields.get("metadata", {}), str):
            fields["metadata"] = json.loads(fields["metadata"])
        current_batch.append(
            KGChunkFormat(
                connector_id=None,  # We may need to adjust this
                document_id=fields.get("document_id"),
                chunk_id=fields.get("chunk_id"),
                primary_owners=fields.get("primary_owners", []),
                secondary_owners=fields.get("secondary_owners", []),
                source_type=fields.get("source_type", ""),
                title=fields.get("title", ""),
                content=fields.get("content", ""),
                metadata=fields.get("metadata", {}),
                entities=fields.get("kg_entities", {}),
                relationships=fields.get("kg_relationships", {}),
                terms=fields.get("kg_terms", {}),
                deep_extraction=deep_extraction,
            )
        )

        if len(current_batch) >= batch_size:
            yield current_batch
            current_batch = []

    # Yield any remaining chunks
    if current_batch:
        yield current_batch


def _get_classification_content_from_call_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Creates the classification content string from a list of call chunks.
    """

    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)

    primary_owners = first_num_classification_chunks[0]["fields"].get(
        "primary_owners", []
    )
    secondary_owners = first_num_classification_chunks[0]["fields"].get(
        "secondary_owners", []
    )

    company_participant_emails = set()
    account_participant_emails = set()

    for owner in primary_owners + secondary_owners:
        kg_owner = kg_email_processing(owner, kg_config_settings)
        if any(
            domain.lower() in kg_owner.company.lower()
            for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
        ):
            continue

        if kg_owner.employee:
            company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
        else:
            account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")

    participant_string = "\n - " + "\n - ".join(company_participant_emails)
    account_participant_string = "\n - " + "\n - ".join(account_participant_emails)

    title_string = first_num_classification_chunks[0]["fields"]["title"]
    content_string = "\n".join(
        [
            chunk_content["fields"]["content"]
            for chunk_content in first_num_classification_chunks
        ]
    )

    classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"

    return classification_content


def _get_classification_content_from_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Creates the classification content string from a list of chunks.
    """

    source_type = first_num_classification_chunks[0]["fields"]["source_type"]

    if source_type.lower() in [call_type.value.lower() for call_type in OnyxCallTypes]:
        classification_content = _get_classification_content_from_call_chunks(
            first_num_classification_chunks,
            kg_config_settings,
        )

    else:
        classification_content = (
            first_num_classification_chunks[0]["fields"]["title"]
            + "\n"
            + "\n".join(
                [
                    chunk_content["fields"]["content"]
                    for chunk_content in first_num_classification_chunks
                ]
            )
        )

    return classification_content
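
As a rough usage sketch, the chunk generator above is meant to be consumed batch by batch. The document ID, index name, and tenant ID below are placeholders for illustration only.

from onyx.kg.vespa.vespa_interactions import get_document_chunks_for_kg_processing

# Placeholder arguments; real values come from the KG processing pipeline.
for batch in get_document_chunks_for_kg_processing(
    document_id="some-document-id",
    deep_extraction=False,
    index_name="danswer_chunk",
    tenant_id="public",
    batch_size=8,
):
    for kg_chunk in batch:
        # Each KGChunkFormat carries the chunk text plus any existing kg_* fields.
        print(kg_chunk.document_id, kg_chunk.chunk_id)
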
@@ -41,6 +41,8 @@ from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE
from onyx.configs.app_configs import SYSTEM_RECURSION_LIMIT
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
@@ -223,6 +225,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    )
    SqlEngine.get_engine()

    SqlEngine.init_readonly_engine(
        pool_size=POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE,
        max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
    )

    verify_auth = fetch_versioned_implementation(
        "onyx.auth.users", "verify_auth_setting"
    )

1197
backend/onyx/prompts/kg_prompts.py
Normal file
File diff suppressed because it is too large. Load Diff
@@ -78,6 +78,11 @@ class SearchToolOverrideKwargs(BaseModel):
    document_sources: list[DocumentSource] | None = None
    time_cutoff: datetime | None = None
    expanded_queries: QueryExpansions | None = None
    kg_entities: list[str] | None = None
    kg_relationships: list[str] | None = None
    kg_terms: list[str] | None = None
    kg_sources: list[str] | None = None
    kg_chunk_id_zero_only: bool | None = False

    class Config:
        arbitrary_types_allowed = True

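
For context, a sketch of how the new KG fields on SearchToolOverrideKwargs could be populated by a caller. The entity and relationship IDs follow the formatting_utils conventions shown earlier; the specific values and the import path are assumptions for illustration, not taken from this commit.

# Assumed import path; adjust to wherever SearchToolOverrideKwargs lives in the codebase.
from onyx.tools.tool_implementations.search.search_tool import SearchToolOverrideKwargs

override_kwargs = SearchToolOverrideKwargs(
    kg_entities=["ACCOUNT::acme corp"],
    kg_relationships=["EMPLOYEE::jane doe__participates_in__ACCOUNT::acme corp"],
    kg_terms=["renewal"],
    kg_sources=["salesforce"],
    kg_chunk_id_zero_only=True,
)
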
@@ -296,6 +296,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
        document_sources = None
        time_cutoff = None
        expanded_queries = None
        kg_entities = None
        kg_relationships = None
        kg_terms = None
        kg_sources = None
        kg_chunk_id_zero_only = False
        if override_kwargs:
            force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
            alternate_db_session = override_kwargs.alternate_db_session
@@ -308,6 +313,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
            document_sources = override_kwargs.document_sources
            time_cutoff = override_kwargs.time_cutoff
            expanded_queries = override_kwargs.expanded_queries
            kg_entities = override_kwargs.kg_entities
            kg_relationships = override_kwargs.kg_relationships
            kg_terms = override_kwargs.kg_terms
            kg_sources = override_kwargs.kg_sources
            kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False

        if self.selected_sections:
            yield from self._build_response_for_specified_sections(query)
@@ -331,6 +341,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
            # The override time-cutoff should supersede the existing time-cutoff, even if defined
            retrieval_options.filters.time_cutoff = time_cutoff

        retrieval_options = retrieval_options or RetrievalDetails()
        retrieval_options.filters = retrieval_options.filters or BaseFilters()
        if kg_entities:
            retrieval_options.filters.kg_entities = kg_entities
        if kg_relationships:
            retrieval_options.filters.kg_relationships = kg_relationships
        if kg_terms:
            retrieval_options.filters.kg_terms = kg_terms
        if kg_sources:
            retrieval_options.filters.kg_sources = kg_sources
        if kg_chunk_id_zero_only:
            retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only

        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=query,

@@ -80,6 +80,7 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.46.1
supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0

@@ -101,6 +101,20 @@ autorestart=true
startsecs=10
stopasgroup=true

[program:celery_worker_kg_processing]
command=celery -A onyx.background.celery.versioned_apps.kg_processing worker
    --loglevel=INFO
    --hostname=kg_processing@%%n
    -Q kg_processing
stdout_logfile=/var/log/celery_worker_kg_processing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true


# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A onyx.background.celery.versioned_apps.beat beat

@@ -155,7 +155,7 @@
        "Email: fiannellib46@marriott.com\nIsDeleted: false\nLastName: Iannelli\nIsEmailBounced: false\nFirstName: Felicio\nIsPriorityRecord: false\nCleanStatus: Pending"
    ],
    "semantic_identifier": "Voonder",
    "metadata": {},
    "metadata": {"object_type": "Account"},
    "primary_owners": {"email": "hagen@danswer.ai", "first_name": "Hagen", "last_name": "oneill"},
    "secondary_owners": null,
    "title": null

@@ -50,15 +50,21 @@ def answer_instance(
        "onyx.chat.answer.fast_gpu_status_request",
        return_value=True,
    )
    return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
    return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config, mocker)


def _answer_fixture_impl(
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
    rerank_settings: RerankingDetails | None = None,
) -> Answer:
    mock_db_session = Mock(spec=Session)
    mock_query = Mock()
    mock_query.all.return_value = []
    mock_db_session.query.return_value = mock_query

    return Answer(
        prompt_builder=AnswerPromptBuilder(
            user_message=default_build_user_message(
@@ -73,7 +79,7 @@ def _answer_fixture_impl(
            raw_user_query=QUERY,
            raw_user_uploaded_files=[],
        ),
        db_session=Mock(spec=Session),
        db_session=mock_db_session,
        answer_style_config=answer_style_config,
        llm=mock_llm,
        fast_llm=mock_llm,
@@ -404,7 +410,11 @@ def test_no_slow_reranking(
        )
    )
    answer_instance = _answer_fixture_impl(
        mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
        mock_llm,
        answer_style_config,
        prompt_config,
        mocker,
        rerank_settings=rerank_settings,
    )

    assert answer_instance.graph_config.inputs.rerank_settings == rerank_settings

@@ -42,6 +42,7 @@ def test_skip_gen_ai_answer_generation_flag(
        "onyx.chat.answer.fast_gpu_status_request",
        return_value=True,
    )

    question = config["question"]
    skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]

@@ -58,8 +59,14 @@ def test_skip_gen_ai_answer_generation_flag(
    mock_llm.stream = Mock()
    mock_llm.stream.return_value = [Mock()]

    # Set up the mock database session
    mock_db_session = Mock(spec=Session)
    mock_query = Mock()
    mock_db_session.query.return_value = mock_query
    mock_query.all.return_value = []  # Return empty list for KGConfig query

    answer = Answer(
        db_session=Mock(spec=Session),
        db_session=mock_db_session,
        answer_style_config=answer_style_config,
        llm=mock_llm,
        fast_llm=mock_llm,

@@ -747,11 +747,17 @@ def test_salesforce_sqlite() -> None:
    sf_db.apply_schema()

    _create_csv_with_example_data(sf_db)

    _test_query(sf_db)

    _test_upsert(sf_db)

    _test_relationships(sf_db)

    _test_account_with_children(sf_db)

    _test_relationship_updates(sf_db)

    _test_get_affected_parent_ids(sf_db)

    sf_db.close()

@@ -183,6 +183,8 @@ services:
      - POSTGRES_HOST=relational_db
      - POSTGRES_USER=${POSTGRES_USER:-}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
      - POSTGRES_DB=${POSTGRES_DB:-}
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
@@ -384,6 +386,8 @@ services:
    environment:
      - POSTGRES_USER=${POSTGRES_USER:-postgres}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
    ports:
      - "5432:5432"
    volumes:

@@ -148,6 +148,8 @@ services:
      - POSTGRES_HOST=relational_db
      - POSTGRES_USER=${POSTGRES_USER:-}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
      - POSTGRES_DB=${POSTGRES_DB:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
@@ -330,6 +332,8 @@ services:
    environment:
      - POSTGRES_USER=${POSTGRES_USER:-postgres}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
    ports:
      - "5432:5432"
    volumes:

@@ -166,6 +166,8 @@ services:
      - POSTGRES_HOST=relational_db
      - POSTGRES_USER=${POSTGRES_USER:-}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
      - POSTGRES_DB=${POSTGRES_DB:-}
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - VESPA_HOST=index
@@ -357,6 +359,8 @@ services:
    environment:
      - POSTGRES_USER=${POSTGRES_USER:-postgres}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
    ports:
      - "5432:5432"
    volumes:

@@ -154,6 +154,8 @@ services:
    environment:
      - POSTGRES_USER=${POSTGRES_USER:-postgres}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
      - DB_READONLY_USER=${DB_READONLY_USER:-}
      - DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
    ports:
      - "5432"
    volumes:

@@ -55,3 +55,10 @@ SESSION_EXPIRE_TIME_SECONDS=604800
# Default values here are what Postgres uses by default, feel free to change.
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password


# Default values for the read-only user used by the knowledge graph and other future read-only purposes.
# Please change the password!
DB_READONLY_USER=db_readonly_user
DB_READONLY_PASSWORD=password

@@ -40,6 +40,16 @@ spec:
            secretKeyRef:
              name: onyx-secrets
              key: postgres_password
        - name: DB_READONLY_USER
          valueFrom:
            secretKeyRef:
              name: onyx-secrets
              key: DB_READONLY_user
        - name: DB_READONLY_PASSWORD
          valueFrom:
            secretKeyRef:
              name: onyx-secrets
              key: DB_READONLY_password
        args: ["-c", "max_connections=250"]
        ports:
        - containerPort: 5432

@@ -532,7 +532,7 @@ export function AssistantEditor({

      // if disable_retrieval is set, set num_chunks to 0
      // to tell the backend to not fetch any documents
      const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0;
      const numChunks = searchToolEnabled ? values.num_chunks || 25 : 0;
      const starterMessages = values.starter_messages
        .filter(
          (message: { message: string }) => message.message.trim() !== ""