Compare commits

...

6 Commits

Author SHA1 Message Date
joachim-danswer
0e38898c7a fixed migration 2025-06-06 16:05:01 -07:00
joachim-danswer
ce6a597eca mt-test observations/fixes 2025-06-06 13:42:27 -07:00
joachim-danswer
d251ba40ae update migration 2025-06-06 08:52:22 -07:00
joachim-danswer
26395d81c9 path fix 2025-06-05 22:16:39 -07:00
joachim-danswer
e1a3e11ec9 github error correction 1 2025-06-05 22:11:14 -07:00
joachim-danswer
e013711664 Initial Knowledge Graph Implementation, including:
- private schema upgrade
 - migration of tenant schema
 - extraction & clustering for KG
 - Graph for KG Answers
2025-06-05 19:46:50 -07:00
93 changed files with 12589 additions and 75 deletions

View File

@@ -147,6 +147,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
@@ -245,6 +247,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \

View File

@@ -183,6 +183,8 @@ jobs:
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \

View File

@@ -0,0 +1,690 @@
"""create knowledge graph tables
Revision ID: 495cb26ce93e
Revises: ca04500b9ee8
Create Date: 2025-03-19 08:51:14.341989
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "ca04500b9ee8"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the knowledge-graph (KG) schema objects.

    Creates the KG tables (config, entity/relationship types, entities,
    relationships, staging variants, search terms), seeds ``kg_config``,
    adds KG columns to ``document`` and ``connector``, builds trigram
    indexes, and installs triggers that keep ``kg_entity.name`` /
    ``name_trigrams`` in sync with the referenced document.
    """
    # Create a new permission-less user to be later used for knowledge graph queries.
    # The user will later get temporary read privileges for a specific view that will be
    # ad hoc generated specific to a knowledge graph query.
    #
    # Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
    # environment variables MUST be set. Otherwise, an exception will be raised.
    print("MULTI_TENANT: ", MULTI_TENANT)  # NOTE(review): debug print left in a migration — consider removing
    if not MULTI_TENANT:
        print("Single tenant mode")  # NOTE(review): debug print
        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
        # Create read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is created in the alembic_tenants migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    -- Check if the read-only user already exists
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- Create the read-only user with the specified password
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- First revoke all privileges to ensure a clean slate
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege to allow the user to connect to the database
                        -- but not perform any operations without additional specific grants
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

    # Grant usage on current schema to readonly user (guarded: the user may not
    # exist yet in this deployment mode).
    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    EXECUTE format('GRANT USAGE ON SCHEMA %I TO %I', current_schema(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )

    # Key/value-style configuration store for the KG feature flags and settings.
    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
        sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
        sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
        sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
    )

    # Insert initial data into kg_config table
    op.bulk_insert(
        sa.table(
            "kg_config",
            sa.column("kg_variable_name", sa.String),
            sa.column("kg_variable_values", postgresql.ARRAY(sa.String)),
        ),
        [
            {"kg_variable_name": "KG_EXPOSED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_ENABLED", "kg_variable_values": ["false"]},
            {"kg_variable_name": "KG_VENDOR", "kg_variable_values": []},
            {"kg_variable_name": "KG_VENDOR_DOMAINS", "kg_variable_values": []},
            {"kg_variable_name": "KG_IGNORE_EMAIL_DOMAINS", "kg_variable_values": []},
            {
                "kg_variable_name": "KG_EXTRACTION_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                "kg_variable_name": "KG_CLUSTERING_IN_PROGRESS",
                "kg_variable_values": ["false"],
            },
            {
                # Default coverage window starts 90 days before the migration runs.
                "kg_variable_name": "KG_COVERAGE_START",
                "kg_variable_values": [
                    (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d")
                ],
            },
            {"kg_variable_name": "KG_MAX_COVERAGE_DAYS", "kg_variable_values": ["90"]},
            {
                "kg_variable_name": "KG_MAX_PARENT_RECURSION_DEPTH",
                "kg_variable_values": ["2"],
            },
        ],
    )

    # Entity-type catalogue (e.g. grounded source types) used by all KG tables below.
    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column(
            "attributes",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True),
        sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
    )

    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGRelationshipTypeExtractionStaging table
    # (same shape as kg_relationship_type, plus a `transferred` flag; no time_updated)
    op.create_table(
        "kg_relationship_type_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )

    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        # Trigrams of the cleaned name; maintained by the triggers installed below.
        sa.Column("name_trigrams", postgresql.ARRAY(sa.String(3)), nullable=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
        sa.UniqueConstraint(
            "name",
            "entity_type_id_name",
            "document_id",
            name="uq_kg_entity_name_type_doc",
        ),
    )
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )

    # Create KGEntityExtractionStaging table
    op.create_table(
        "kg_entity_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        # id_name of the kg_entity row this staging row was promoted to (if any).
        sa.Column("transferred_id_name", sa.String(), nullable=True, default=None),
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
        sa.Column("entity_key", sa.String(), nullable=True, index=True),
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
        sa.Column("parent_key", sa.String(), nullable=True, index=True),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
    )
    op.create_index(
        "ix_entity_extraction_staging_acl",
        "kg_entity_extraction_staging",
        ["entity_type_id_name", "acl"],
    )
    op.create_index(
        "ix_entity_extraction_staging_name_search",
        "kg_entity_extraction_staging",
        ["name", "entity_type_id_name"],
    )

    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
        # NOTE(review): composite PK includes nullable `source_document`; Postgres
        # forces PK columns NOT NULL, so NULL source_document rows cannot exist —
        # confirm this is intended.
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )

    # Create KGRelationshipExtractionStaging table
    op.create_table(
        "kg_relationship_extraction_staging",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False),
        sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"],
            ["kg_relationship_type_extraction_staging.id_name"],
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_extraction_staging_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_extraction_staging_nodes",
        "kg_relationship_extraction_staging",
        ["source_node", "target_node"],
    )

    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])

    # KG bookkeeping columns on existing tables.
    op.add_column(
        "document",
        sa.Column("kg_stage", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "document",
        sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_processing_enabled",
            sa.Boolean(),
            nullable=True,
            server_default="false",
        ),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_coverage_days",
            sa.Integer(),
            nullable=True,
            server_default=None,
        ),
    )

    # Create GIN index for clustering and normalization
    # NOTE(review): CREATE INDEX CONCURRENTLY cannot run inside a transaction
    # block; downgrade() issues an explicit COMMIT first but upgrade() does not —
    # confirm this migration runs in autocommit mode.
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_clustering_trigrams "
        "ON kg_entity USING GIN (name public.gin_trgm_ops)"
    )
    op.execute(
        "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_kg_entity_normalization_trigrams "
        "ON kg_entity USING GIN (name_trigrams)"
    )

    # Create kg_entity trigger to update kg_entity.name and its trigrams
    alphanum_pattern = r"[^a-z0-9]+"  # strip everything but lowercase alphanumerics
    truncate_length = 1000  # cap cleaned name length before trigram generation
    function = "update_kg_entity_name"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                name text;
                cleaned_name text;
            BEGIN
                -- Set name to semantic_id if document_id is not NULL
                IF NEW.document_id IS NOT NULL THEN
                    SELECT lower(semantic_id) INTO name
                    FROM document
                    WHERE id = NEW.document_id;
                ELSE
                    name = lower(NEW.name);
                END IF;
                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;
                -- Set name and name trigrams
                NEW.name = name;
                NEW.name_trigrams = public.show_trgm(cleaned_name);
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON kg_entity")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        BEFORE INSERT OR UPDATE OF name
        ON kg_entity
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )

    # Create document trigger to propagate semantic_id renames into kg_entity
    # names/trigrams for all entities referencing the document.
    function = "update_kg_entity_name_from_doc"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function}()
            RETURNS TRIGGER AS $$
            DECLARE
                doc_name text;
                cleaned_name text;
            BEGIN
                doc_name = lower(NEW.semantic_id);
                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    doc_name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;
                -- Set name and name trigrams for all entities referencing this document
                UPDATE kg_entity
                SET
                    name = doc_name,
                    name_trigrams = public.show_trgm(cleaned_name)
                WHERE document_id = NEW.id;
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON document")
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
        AFTER UPDATE OF semantic_id
        ON document
        FOR EACH ROW
        EXECUTE FUNCTION {function}();
        """
    )
def downgrade() -> None:
    """Remove all KG schema objects created in upgrade().

    Drops the dynamically generated KG views, the triggers/functions, the
    trigram indexes, all KG tables (reverse dependency order), the KG columns
    on ``document``/``connector``, and finally the read-only user (single
    tenant) or its schema grants (multi tenant).
    """
    # Drop all dynamically generated views matching 'kg_relationships_with_access%'
    # in the current schema (created ad hoc per KG query, not by this migration).
    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'kg_relationships_with_access%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )
    # Drop all dynamically generated 'allowed_docs%' views in the current schema.
    op.execute(
        """
        DO $$
        DECLARE
            view_name text;
        BEGIN
            FOR view_name IN
                SELECT c.relname
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                WHERE c.relkind = 'v'
                AND n.nspname = current_schema()
                AND c.relname LIKE 'allowed_docs%'
            LOOP
                EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name);
            END LOOP;
        END $$;
        """
    )
    # Drop the name-sync triggers and their functions.
    for table, function in (
        ("kg_entity", "update_kg_entity_name"),
        ("document", "update_kg_entity_name_from_doc"),
    ):
        op.execute(f"DROP TRIGGER IF EXISTS {function}_trigger ON {table}")
        op.execute(f"DROP FUNCTION IF EXISTS {function}()")
    # Drop index
    op.execute("COMMIT")  # Commit to allow CONCURRENTLY
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_clustering_trigrams")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_kg_entity_normalization_trigrams")
    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_relationship_extraction_staging")
    op.drop_table("kg_relationship_type_extraction_staging")
    op.drop_table("kg_entity_extraction_staging")
    op.drop_table("kg_entity_type")
    op.drop_column("connector", "kg_processing_enabled")
    op.drop_column("connector", "kg_coverage_days")
    op.drop_column("document", "kg_stage")
    op.drop_column("document", "kg_processing_time")
    op.drop_table("kg_config")
    if not MULTI_TENANT:
        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is dropped in the alembic_tenants migration.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
    else:
        # Multi-tenant: only revoke this tenant schema's grants; the shared user
        # itself is dropped by the alembic_tenants migration.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

View File

@@ -0,0 +1,79 @@
"""add_db_readonly_user
Revision ID: 3b9f09038764
Revises: 3b45e0018bf1
Create Date: 2025-05-11 11:05:11.436977
"""
from sqlalchemy import text
from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
revision = "3b9f09038764"
down_revision = "3b45e0018bf1"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the shared read-only DB user (multi-tenant deployments only).

    Raises:
        Exception: if DB_READONLY_USER or DB_READONLY_PASSWORD is unset.
    """
    if MULTI_TENANT:
        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
        # Create read-only db user here only in multi-tenant mode. For single-tenant mode,
        # the user is created in the standard migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    -- Check if the read-only user already exists
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- Create the read-only user with the specified password
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- First revoke all privileges to ensure a clean slate
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege to allow the user to connect to the database
                        -- but not perform any operations without additional specific grants
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
def downgrade() -> None:
    """Drop the shared read-only DB user (multi-tenant deployments only)."""
    if MULTI_TENANT:
        # Drop read-only db user here only in multi-tenant mode. For single-tenant
        # mode, the user is dropped in the standard migration.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then revoke all privileges from the public schema
                        EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

View File

@@ -150,14 +150,12 @@ def research_object_source(
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
primary_llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,

View File

@@ -71,15 +71,12 @@ def consolidate_object_research(
),
)
]
graph_config.tooling.primary_llm
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
primary_llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,

View File

@@ -0,0 +1,70 @@
from collections.abc import Hashable
from datetime import datetime
from enum import Enum
from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
class KGAnalysisPath(str, Enum):
    # Conditional-edge target node names used by simple_vs_search routing.
    PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
    CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
def simple_vs_search(
    state: MainState,
) -> str:
    """Route after SQL generation.

    Returns the deep-search-filter path when the strategy in effect is DEEP
    or the search type is SEARCH; otherwise the KG-only-answers path.
    """
    strategy_in_effect = state.updated_strategy or state.strategy
    wants_deep_path = (
        strategy_in_effect == KGAnswerStrategy.DEEP
        or state.search_type == KGSearchType.SEARCH
    )
    if wants_deep_path:
        return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
    return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
def research_individual_object(
    state: MainState,
) -> list[Send | Hashable] | str:
    """Conditional edge after deep-search filter construction.

    For SQL + DEEP: fan out one `Send` per div/con entity so each is researched
    in parallel by `process_individual_deep_search`. For SEARCH: route to the
    `filtered_search` node. Any other combination is invalid and raises.

    Raises:
        ValueError: on an unsupported search_type/strategy combination.
    """
    edge_start_time = datetime.now()
    # These fields must have been populated by upstream nodes.
    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None
    if (
        state.search_type == KGSearchType.SQL
        and state.strategy == KGAnswerStrategy.DEEP
    ):
        return [
            Send(
                "process_individual_deep_search",
                ResearchObjectInput(
                    research_nr=research_nr + 1,  # 1-based numbering
                    entity=entity,
                    broken_down_question=state.broken_down_question,
                    vespa_filter_results=state.vespa_filter_results,
                    source_division=state.source_division,
                    source_entity_filters=state.source_filters,
                    log_messages=[
                        f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                    ],
                    step_results=[],
                ),
            )
            for research_nr, entity in enumerate(state.div_con_entities)
        ]
    elif state.search_type == KGSearchType.SEARCH:
        return "filtered_search"
    else:
        raise ValueError(
            f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
        )

View File

@@ -0,0 +1,143 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.nodes.d2_logging_node import log_data
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kb_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.

    Wires the KG pipeline: entity/relationship extraction -> analysis ->
    simple SQL generation, then either KG-only answers or deep-search
    (parallel per-entity research or filtered search), ending in answer
    generation and logging.
    """
    graph = StateGraph(state_schema=MainState, input=MainInput)

    # --- Register nodes ---
    graph.add_node("extract_ert", extract_ert)
    graph.add_node("generate_simple_sql", generate_simple_sql)
    graph.add_node("filtered_search", filtered_search)
    graph.add_node("analyze", analyze)
    graph.add_node("generate_answer", generate_answer)
    graph.add_node("log_data", log_data)
    graph.add_node("construct_deep_search_filters", construct_deep_search_filters)
    graph.add_node("process_individual_deep_search", process_individual_deep_search)
    graph.add_node(
        "consolidate_individual_deep_search", consolidate_individual_deep_search
    )
    graph.add_node("process_kg_only_answers", process_kg_only_answers)

    # --- Wire edges ---
    # Linear prefix: extraction -> analysis -> SQL generation.
    graph.add_edge(START, "extract_ert")
    graph.add_edge("extract_ert", "analyze")
    graph.add_edge("analyze", "generate_simple_sql")

    # Route to KG-only answers or deep-search filter construction.
    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
    graph.add_edge("process_kg_only_answers", "generate_answer")

    # Deep-search branch: fan out per entity, or fall back to filtered search.
    graph.add_conditional_edges(
        "construct_deep_search_filters",
        research_individual_object,
        ["process_individual_deep_search", "filtered_search"],
    )
    graph.add_edge(
        "process_individual_deep_search", "consolidate_individual_deep_search"
    )
    graph.add_edge("consolidate_individual_deep_search", "generate_answer")
    graph.add_edge("filtered_search", "generate_answer")

    # Common suffix: answer then logging.
    graph.add_edge("generate_answer", "log_data")
    graph.add_edge("log_data", END)

    return graph

View File

@@ -0,0 +1,409 @@
import re
from time import sleep
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import LlmDoc
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entities import get_entity_name
from onyx.db.entity_type import get_entity_types
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _check_entities_disconnected(
    current_entities: list[str], current_relationships: list[str]
) -> bool:
    """
    Check if all entities in current_entities are disconnected via the given relationships.
    Relationships are in the format: source_entity__relationship_name__target_entity

    Args:
        current_entities: List of entity IDs to check connectivity for
        current_relationships: List of relationships in format source__relationship__target

    Returns:
        bool: True if at least one entity is not reachable from the first entity
        (i.e. the set is disconnected), False if every entity is reachable.

    Raises:
        ValueError: if a relationship string cannot be split into three parts.
    """
    # Local import keeps the module import block untouched; deque gives O(1)
    # popleft for the BFS queue (list.pop(0) was O(n) per pop).
    from collections import deque

    if not current_entities:
        return True

    # Adjacency list over the entities of interest only.
    graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
    for relationship in current_relationships:
        try:
            source, _, target = split_relationship_id(relationship)
        except ValueError:
            raise ValueError(f"Invalid relationship format: {relationship}")
        if source in graph and target in graph:
            graph[source].add(target)
            # Add reverse edge to capture that we do also have a relationship in the
            # other direction, albeit not quite the same one.
            graph[target].add(source)

    # BFS from the first entity to find everything reachable from it.
    visited: set[str] = {current_entities[0]}
    queue = deque(visited)
    while queue:
        current = queue.popleft()
        for neighbor in graph[current]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)

    logger.debug(f"Number of visited entities: {len(visited)}")
    # Disconnected iff some requested entity was never reached.
    return not visited.issuperset(current_entities)
def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
    """
    TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
    Return the original entities and relationships.
    """
    # Placeholder: ``max_depth`` is currently ignored and the inputs are
    # passed through unchanged until graph expansion is implemented.
    return KGExpandedGraphObjects(entities=entities, relationships=relationships)
def stream_write_step_description(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Stream the canned description of step ``step_nr`` to the frontend."""
    description_piece = SubQuestionPiece(
        sub_question=STEP_DESCRIPTIONS[step_nr].description,
        level=level,
        level_question_num=step_nr,
    )
    write_custom_event("decomp_qs", description_piece, writer)
    # Give the frontend a brief moment to catch up
    sleep(0.2)
def stream_write_step_activities(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Stream every canned activity of step ``step_nr`` as a sub-query event."""
    for query_id, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities, start=1):
        activity_piece = SubQueryPiece(
            sub_query=activity,
            level=level,
            level_question_num=step_nr,
            query_id=query_id,
        )
        write_custom_event("subqueries", activity_piece, writer)
def stream_write_step_activity_explicit(
    writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
) -> None:
    """
    Stream a single, caller-provided activity for step ``step_nr``.

    Bug fix: the previous implementation looped over the canned
    STEP_DESCRIPTIONS activities, shadowing the ``activity`` argument and
    re-sending every canned activity under the same ``query_id``. An
    "explicit" write should emit exactly the activity that was passed in
    (mirroring stream_write_step_answer_explicit).
    """
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query=activity,
            level=level,
            level_question_num=step_nr,
            query_id=query_id,
        ),
        writer,
    )
def stream_write_step_answer_explicit(
    writer: StreamWriter, step_nr: int, answer: str, level: int = 0
) -> None:
    """Stream a caller-provided sub-answer for step ``step_nr``."""
    answer_piece = AgentAnswerPiece(
        answer_piece=answer,
        level=level,
        level_question_num=step_nr,
        answer_type="agent_sub_answer",
    )
    write_custom_event("sub_answers", answer_piece, writer)
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
    """
    Announce the overall step structure to the frontend: one sub-question
    event per step, then a stream_finished event per step, and finally a
    closing stream_finished for the sub-question stream as a whole.
    """
    for step_nr, step_detail in STEP_DESCRIPTIONS.items():
        write_custom_event(
            "decomp_qs",
            SubQuestionPiece(
                sub_question=step_detail.description,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )
    for step_nr in STEP_DESCRIPTIONS.keys():
        write_custom_event(
            "stream_finished",
            StreamStopInfo(
                stop_reason=StreamStopReason.FINISHED,
                stream_type=StreamType.SUB_QUESTIONS,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )
    # Fix: the closing event previously hard-coded level=0, which mislabeled
    # it whenever a non-zero level was passed in.
    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_QUESTIONS,
        level=level,
    )
    write_custom_event("stream_finished", stop_event, writer)
def stream_close_step_answer(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Emit the stream_finished event that closes a step's sub-answer stream."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.SUB_ANSWER,
            level=level,
            level_question_num=step_nr,
        ),
        writer,
    )
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
    """Close the sub-question stream for the given level."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.SUB_QUESTIONS,
            level=level,
        ),
        writer,
    )
def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
    """Close the main-answer stream for the given level."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.MAIN_ANSWER,
            level=level,
            level_question_num=0,
        ),
        writer,
    )
def stream_write_main_answer_token(
    writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
) -> None:
    """Stream a single main-answer token to the frontend."""
    token_piece = AgentAnswerPiece(
        answer_piece=token,  # No need to add space as tokenizer handles this
        level=level,
        level_question_num=level_question_num,
        answer_type="agent_level_answer",
    )
    write_custom_event("initial_agent_answer", token_piece, writer)
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
    """
    Get document information for an entity, including its semantic name and document details.

    Args:
        entity_id_name: Entity identifier of the form "<ENTITY_TYPE>::<entity_name>".
            Inputs without "::" are returned unchanged with no document info.

    Returns:
        KGEntityDocInfo populated from the backing document when the entity
        maps to one; otherwise a doc-less record whose semantic names fall
        back to "<entity_type> <entity name or raw id>".
    """
    if "::" not in entity_id_name:
        return KGEntityDocInfo(
            doc_id=None,
            doc_semantic_id=None,
            doc_link=None,
            semantic_entity_name=entity_id_name,
            semantic_linked_entity_name=entity_id_name,
        )

    # Only the type part is needed locally; the name half is resolved via the
    # DB lookups below (the previous unpacking left an unused local).
    entity_type = entity_id_name.split("::", 1)[0].strip()

    with get_session_with_current_tenant() as db_session:
        entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
        if entity_document_id:
            return get_kg_doc_info_for_entity_name(
                db_session, entity_document_id, entity_type
            )
        else:
            entity_actual_name = get_entity_name(db_session, entity_id_name)
            return KGEntityDocInfo(
                doc_id=None,
                doc_semantic_id=None,
                doc_link=None,
                semantic_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
                semantic_linked_entity_name=f"{entity_type} {entity_actual_name or entity_id_name}",
            )
def rename_entities_in_answer(answer: str) -> str:
    """
    Process entity references in the answer string by:
    1. Extracting all strings matching <entity_type>::<entity_name> patterns
       (spaces after "::" are normalized away first)
    2. Looking up these references in the entity table
    3. Replacing valid references with their corresponding semantic names
       (linked variants when a backing document is known)
    Args:
        answer: The input string containing potential entity references
    Returns:
        str: The processed string with entity references replaced
    """
    logger.debug(f"Input answer: {answer}")
    # Clean up any spaces around ::
    answer = re.sub(r"::\s+", "::", answer)
    logger.debug(f"After cleaning spaces: {answer}")
    # Pattern to match entity_type::entity_name, with optional quotes
    # NOTE(review): entity names are restricted to [a-zA-Z0-9]; names with
    # dashes/underscores would not match — confirm this is intended.
    pattern = r"(?:')?([a-zA-Z0-9-]+)::([a-zA-Z0-9]+)(?:')?"
    logger.debug(f"Using pattern: {pattern}")
    matches = list(re.finditer(pattern, answer))
    logger.debug(f"Found {len(matches)} matches")
    # get active entity types
    with get_session_with_current_tenant() as db_session:
        active_entity_types = [
            x.id_name for x in get_entity_types(db_session, active=True)
        ]
    logger.debug(f"Active entity types: {active_entity_types}")
    # Create dictionary for processed references
    processed_refs: dict[str, str] = {}
    for match in matches:
        entity_type = match.group(1).upper().strip()
        entity_name = match.group(2).strip()
        potential_entity_id_name = make_entity_id(entity_type, entity_name)
        logger.debug(f"Processing entity: {potential_entity_id_name}")
        # Skip references whose type is not an active KG entity type
        if entity_type not in active_entity_types:
            logger.debug(f"Entity type {entity_type} not in active types")
            continue
        replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
        if replacement_candidate.doc_id:
            # Store both the original match and the entity_id_name for replacement
            processed_refs[match.group(0)] = (
                replacement_candidate.semantic_linked_entity_name
            )
            logger.debug(
                f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_linked_entity_name}"
            )
        else:
            processed_refs[match.group(0)] = replacement_candidate.semantic_entity_name
            logger.debug(
                f"Added replacement: {match.group(0)} -> {replacement_candidate.semantic_entity_name}"
            )
    # Replace all references in the answer
    for ref, replacement in processed_refs.items():
        answer = answer.replace(ref, replacement)
        logger.debug(f"Replaced {ref} with {replacement}")
    return answer
def build_document_context(
    document: InferenceSection | LlmDoc, document_number: int
) -> str:
    """
    Build a context string for a document.

    Combines a numbered header, any document metadata, and the document's
    text content into a single prompt-ready string.

    Args:
        document: Either an InferenceSection (content taken from
            combined_content / center_chunk) or an LlmDoc.
        document_number: Number used in the "Document N" header.

    Returns:
        The formatted context string.

    Raises:
        TypeError: If ``document`` is neither an InferenceSection nor an LlmDoc.
    """
    info_source: InferenceChunk | LlmDoc
    if isinstance(document, InferenceSection):
        info_source = document.center_chunk
        info_content = document.combined_content
    elif isinstance(document, LlmDoc):
        info_source = document
        info_content = document.content
    else:
        # Previously an unexpected type fell through and crashed later with an
        # opaque AttributeError on None; fail fast with a clear message.
        raise TypeError(f"Unsupported document type: {type(document).__name__}")

    metadata_list: list[str] = []
    for key, value in info_source.metadata.items():
        metadata_list.append(f" - {key}: {value}")
    if metadata_list:
        metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
    else:
        metadata_str = ""

    # Construct document header with number and semantic identifier
    doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"

    # Combine all parts with proper spacing
    return f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
def get_near_empty_step_results(
    step_number: int,
    step_answer: str,
    verified_reranked_documents: list[InferenceSection] | None = None,
) -> SubQuestionAnswerResults:
    """
    Build a mostly-empty SubQuestionAnswerResults for the given step.

    Args:
        step_number: Step number; the question text is looked up in
            STEP_DESCRIPTIONS.
        step_answer: Answer text to attach to the step.
        verified_reranked_documents: Optional documents to attach; defaults to
            an empty list. (The previous signature used a mutable ``[]``
            default, which is shared across calls in Python.)

    Returns:
        The populated SubQuestionAnswerResults.
    """
    return SubQuestionAnswerResults(
        question=STEP_DESCRIPTIONS[step_number].description,
        question_id="0_" + str(step_number),
        answer=step_answer,
        verified_high_quality=True,
        sub_query_retrieval_results=[],
        verified_reranked_documents=verified_reranked_documents or [],
        context_documents=[],
        cited_documents=[],
        sub_question_retrieval_stats=AgentChunkRetrievalStats(
            verified_count=None,
            verified_avg_scores=None,
            rejected_count=None,
            rejected_avg_scores=None,
            verified_doc_chunk_ids=[],
            dismissed_doc_chunk_ids=[],
        ),
    )

View File

@@ -0,0 +1,49 @@
from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
    """Entities and free-text terms extracted from a user question."""

    # Entity id strings, possibly carrying "--attribute" suffixes
    entities: list[str]
    # Non-entity keywords extracted from the question
    terms: list[str]
    # Time-filter expression from the question; may be None (or "" on
    # extraction-failure fallbacks)
    time_filter: str | None
class KGAnswerApproach(BaseModel):
    """LLM-selected plan for answering a KG question."""

    search_type: KGSearchType
    search_strategy: KGAnswerStrategy
    format: KGAnswerFormat
    # Optional reformulation of the user question
    broken_down_question: str | None = None
    # presumably whether to answer via independent sub-questions — verify
    # against the strategy-generation prompt
    divide_and_conquer: YesNoEnum | None = None
class KGQuestionRelationshipExtractionResult(BaseModel):
    """Relationship id strings extracted from a user question."""

    relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
    """Combined entity/relationship/term extraction result for a question."""

    entities: list[str]
    relationships: list[str]
    terms: list[str]
    time_filter: str | None
class KGExpandedGraphObjects(BaseModel):
    """Entities and relationships after query-graph expansion."""

    entities: list[str]
    relationships: list[str]
class KGSteps(BaseModel):
    """UI description of one KG answering step."""

    # Step headline, streamed to the frontend as a sub-question
    description: str
    # Sub-activities, streamed to the frontend as sub-queries
    activities: list[str]
class KGEntityDocInfo(BaseModel):
    """Document info backing an entity, plus display name variants."""

    doc_id: str | None
    doc_semantic_id: str | None
    doc_link: str | None
    # Human-readable entity name
    semantic_entity_name: str
    # Name variant used when a backing document/link is available
    semantic_linked_entity_name: str

View File

@@ -0,0 +1,274 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_ENTITY_EXTRACTION_TIMEOUT
from onyx.configs.kg_configs import KG_RELATIONSHIP_EXTRACTION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import get_user_view_names
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
    """
    LangGraph node to start the agentic search process.

    KG step 1: extracts entities, terms and a time filter from the user
    question with one LLM call, then extracts relationships with a second
    LLM call, creates the per-user temporary DB views used by later nodes,
    and streams step-1 progress/results to the frontend.
    """
    # recheck KG enablement at outset KG graph
    if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
        logger.error("KG approach is not enabled, the KG agent flow cannot run.")
        raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")

    _KG_STEP_NR = 1

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        user_email = graph_config.tooling.search_tool.user.email
        # NOTE(review): split("@")[0] is only falsy for emails starting with
        # "@", so the "unknown" fallback rarely triggers — confirm intent.
        user_name = user_email.split("@")[0] or "unknown"

    # first four lines duplicates from generate_initial_answer
    question = graph_config.inputs.prompt_builder.raw_user_query
    today_date = datetime.now().strftime("%A, %Y-%m-%d")

    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)

    # Stream structure of substeps out to the UI
    stream_write_step_structure(writer)

    # Now specify core activities in the step (step 1)
    stream_write_step_activities(writer, _KG_STEP_NR)

    # Create temporary views. TODO: move into parallel step, if ultimately materialized
    allowed_docs_view_name, kg_relationships_view_name = get_user_view_names(user_email)
    with get_session_with_current_tenant() as db_session:
        create_views(
            db_session,
            user_email=user_email,
            allowed_docs_view_name=allowed_docs_view_name,
            kg_relationships_view_name=kg_relationships_view_name,
        )

    ### get the entities, terms, and filters
    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )
    # Placeholders are substituted via replace(); doubled braces (escaped for
    # str.format above) are collapsed back to single braces.
    query_extraction_prompt = (
        query_extraction_pre_prompt.replace("---content---", question)
        .replace("---today_date---", today_date)
        .replace("---user_name---", f"EMPLOYEE:{user_name}")
        .replace("{{", "{")
        .replace("}}", "}")
    )

    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm

    # Grader
    try:
        llm_response = run_with_timeout(
            KG_ENTITY_EXTRACTION_TIMEOUT,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )

        # Strip markdown fences and newlines, then slice down to the
        # outermost JSON object in the response.
        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            entity_extraction_result = (
                KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # NOTE(review): fallback uses time_filter="" although the field is
            # declared str | None — confirm downstream treats "" like None.
            entity_extraction_result = KGQuestionEntityExtractionResult(
                entities=[],
                terms=[],
                time_filter="",
            )
    except Exception as e:
        # LLM call failed or timed out; continue with an empty extraction.
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[],
            terms=[],
            time_filter="",
        )

    # remove the attribute filters from the entities to for the purpose of the relationship
    entities_no_attributes = [
        entity.split("--")[0] for entity in entity_extraction_result.entities
    ]

    ert_entities_string = f"Entities: {entities_no_attributes}\n"

    ### get the relationships
    # find the relationship types that match the extracted entity types
    with get_session_with_current_tenant() as db_session:
        allowed_relationship_pairs = get_allowed_relationship_type_pairs(
            db_session, entity_extraction_result.entities
        )

    query_relationship_extraction_prompt = (
        QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
        .replace("---today_date---", today_date)
        .replace(
            "---relationship_type_options---",
            " - " + "\n - ".join(allowed_relationship_pairs),
        )
        .replace("---identified_entities---", ert_entities_string)
        .replace("---entity_types---", all_entity_types)
        .replace("{{", "{")
        .replace("}}", "}")
    )
    msg = [
        HumanMessage(
            content=query_relationship_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm

    # Grader
    try:
        llm_response = run_with_timeout(
            KG_RELATIONSHIP_EXTRACTION_TIMEOUT,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )
        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        # NOTE(review): doubled braces surviving the slice are rewritten into
        # quoted JSON braces here — confirm this matches the prompt's
        # expected output format.
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            relationship_extraction_result = (
                KGQuestionRelationshipExtractionResult.model_validate_json(
                    cleaned_response
                )
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            relationship_extraction_result = KGQuestionRelationshipExtractionResult(
                relationships=[],
            )
    except Exception as e:
        # LLM call failed or timed out; continue without relationships.
        logger.error(f"Error in extract_ert: {e}")
        relationship_extraction_result = KGQuestionRelationshipExtractionResult(
            relationships=[],
        )

    ## STEP 1
    # Stream answer pieces out to the UI for Step 1
    extracted_entity_string = " \n ".join(
        [x.split("--")[0] for x in entity_extraction_result.entities]
    )
    extracted_relationship_string = " \n ".join(
        relationship_extraction_result.relationships
    )
    step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""

    stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)

    # Finish Step 1
    stream_close_step_answer(writer, _KG_STEP_NR)

    return ERTExtractionUpdate(
        entities_types_str=all_entity_types,
        relationship_types_str=all_relationship_types,
        extracted_entities_w_attributes=entity_extraction_result.entities,
        extracted_entities_no_attributes=entities_no_attributes,
        extracted_relationships=relationship_extraction_result.relationships,
        extracted_terms=entity_extraction_result.terms,
        time_filter=entity_extraction_result.time_filter,
        kg_doc_temp_view_name=allowed_docs_view_name,
        kg_rel_temp_view_name=kg_relationships_view_name,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="extract entities terms",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -0,0 +1,311 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_STRATEGY_GENERATION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _articulate_normalizations(
entity_normalization_map: dict[str, str],
relationship_normalization_map: dict[str, str],
) -> str:
remark_list: list[str] = []
if entity_normalization_map:
remark_list.append("\n Entities:")
for extracted_entity, normalized_entity in entity_normalization_map.items():
remark_list.append(f" - {extracted_entity} -> {normalized_entity}")
if relationship_normalization_map:
remark_list.append(" \n Relationships:")
for (
extracted_relationship,
normalized_relationship,
) in relationship_normalization_map.items():
remark_list.append(
f" - {extracted_relationship} -> {normalized_relationship}"
)
return " \n ".join(remark_list)
def _get_fully_connected_entities(
entities: list[str], relationships: list[str]
) -> list[str]:
"""
Analyze the connectedness of the entities and relationships.
"""
# Build a dictionary to track connections for each entity
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
# Parse relationships to build connection graph
for relationship in relationships:
# Split relationship into parts. Test for proper formatting just in case.
# Should never be an error though at this point.
parts = split_relationship_id(relationship)
if len(parts) != 3:
raise ValueError(f"Invalid relationship: {relationship}")
entity1 = parts[0]
entity2 = parts[2]
# Add bidirectional connections
if entity1 in entity_connections:
entity_connections[entity1].add(entity2)
if entity2 in entity_connections:
entity_connections[entity2].add(entity1)
# Find entities connected to all others
fully_connected_entities = []
all_entities = set(entities)
for entity, connections in entity_connections.items():
# Check if this entity is connected to all other entities
if connections == all_entities - {entity}:
fully_connected_entities.append(entity)
return fully_connected_entities
def _check_for_single_doc(
normalized_entities: list[str],
raw_entities: list[str],
normalized_relationship_strings: list[str],
raw_relationships: list[str],
normalized_time_filter: str | None,
) -> str | None:
"""
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
None is returned if the query is not for a single document.
"""
if (
len(normalized_entities) == 1
and len(raw_entities) == 1
and len(normalized_relationship_strings) == 0
and len(raw_relationships) == 0
and normalized_time_filter is None
):
with get_session_with_current_tenant() as db_session:
single_doc_id = get_document_id_for_entity(
db_session, normalized_entities[0]
)
else:
single_doc_id = None
return single_doc_id
def analyze(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
    """
    LangGraph node to start the agentic search process.

    KG step 2: normalizes the extracted entities/relationships/terms against
    the knowledge graph, short-circuits single-document questions, expands the
    query graph, and asks the LLM for the answer strategy and output format.
    """
    _KG_STEP_NR = 2

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query

    entities = (
        state.extracted_entities_no_attributes
    )  # attribute knowledge is not required for this step
    relationships = state.extracted_relationships
    terms = state.extracted_terms
    time_filter = state.time_filter

    ## STEP 2 - stream out goals
    stream_write_step_activities(writer, _KG_STEP_NR)

    # Continue with node
    normalized_entities = normalize_entities(
        entities, allowed_docs_temp_view_name=state.kg_doc_temp_view_name
    )
    query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
        state.extracted_entities_w_attributes,
        normalized_entities.entity_normalization_map,
    )
    normalized_relationships = normalize_relationships(
        relationships, normalized_entities.entity_normalization_map
    )
    normalized_terms = normalize_terms(terms)
    # NOTE(review): the time filter is currently passed through unnormalized.
    normalized_time_filter = time_filter

    # If single-doc inquiry, send to single-doc processing directly
    single_doc_id = _check_for_single_doc(
        normalized_entities=normalized_entities.entities,
        raw_entities=entities,
        normalized_relationship_strings=normalized_relationships.relationships,
        raw_relationships=relationships,
        normalized_time_filter=normalized_time_filter,
    )

    # Expand the entities and relationships to make sure that entities are connected
    graph_expansion = create_minimal_connected_query_graph(
        normalized_entities.entities,
        normalized_relationships.relationships,
        max_depth=2,
    )
    query_graph_entities = graph_expansion.entities
    query_graph_relationships = graph_expansion.relationships

    # Evaluate whether a search needs to be done after identifying all entities and relationships
    strategy_generation_prompt = (
        STRATEGY_GENERATION_PROMPT.replace(
            "---entities---", "\n".join(query_graph_entities)
        )
        .replace("---relationships---", "\n".join(query_graph_relationships))
        .replace("---possible_entities---", state.entities_types_str)
        .replace("---possible_relationships---", state.relationship_types_str)
        .replace("---question---", question)
    )
    msg = [
        HumanMessage(
            content=strategy_generation_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm

    # Grader
    try:
        llm_response = run_with_timeout(
            KG_STRATEGY_GENERATION_TIMEOUT,
            # fast_llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=5,
            max_tokens=100,
        )

        # Strip markdown fences, then slice down to the outermost JSON object.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]

        try:
            approach_extraction_result = KGAnswerApproach.model_validate_json(
                cleaned_response
            )
            search_type = approach_extraction_result.search_type
            search_strategy = approach_extraction_result.search_strategy
            output_format = approach_extraction_result.format
            broken_down_question = approach_extraction_result.broken_down_question
            divide_and_conquer = approach_extraction_result.divide_and_conquer
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # Conservative defaults when the LLM response is unusable
            search_type = KGSearchType.SEARCH
            search_strategy = KGAnswerStrategy.DEEP
            output_format = KGAnswerFormat.TEXT
            broken_down_question = None
            divide_and_conquer = YesNoEnum.NO

        if search_strategy is None or output_format is None:
            raise ValueError(f"Invalid strategy: {cleaned_response}")

    except Exception as e:
        # Unlike extract_ert, strategy generation is fatal for the flow
        logger.error(f"Error in strategy generation: {e}")
        raise e

    # Stream out relevant results
    if single_doc_id:
        search_strategy = (
            KGAnswerStrategy.DEEP
        )  # if a single doc is identified, we will want to look at the details.

    step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"

    stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
    stream_close_step_answer(writer, _KG_STEP_NR)

    # End node
    return AnalysisUpdate(
        normalized_core_entities=normalized_entities.entities,
        normalized_core_relationships=normalized_relationships.relationships,
        entity_normalization_map=normalized_entities.entity_normalization_map,
        relationship_normalization_map=normalized_relationships.relationship_normalization_map,
        query_graph_entities_no_attributes=query_graph_entities,
        query_graph_entities_w_attributes=query_graph_entities_w_attributes,
        query_graph_relationships=query_graph_relationships,
        normalized_terms=normalized_terms.terms,
        normalized_time_filter=normalized_time_filter,
        strategy=search_strategy,
        broken_down_question=broken_down_question,
        output_format=output_format,
        divide_and_conquer=divide_and_conquer,
        single_doc_id=single_doc_id,
        search_type=search_type,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="analyze",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
        remarks=[
            _articulate_normalizations(
                entity_normalization_map=normalized_entities.entity_normalization_map,
                relationship_normalization_map=normalized_relationships.relationship_normalization_map,
            )
        ],
    )

View File

@@ -0,0 +1,415 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from shared_configs.contextvars import get_current_tenant_id
from onyx.configs.kg_configs import KG_MAX_DEEP_SEARCH_RESULTS
from onyx.configs.kg_configs import KG_SQL_GENERATION_TIMEOUT
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _drop_temp_views(
    allowed_docs_view_name: str, kg_relationships_view_name: str
) -> None:
    """Drop the per-user temporary KG views created for this query."""
    with get_session_with_current_tenant() as session:
        drop_views(
            session,
            allowed_docs_view_name=allowed_docs_view_name,
            kg_relationships_view_name=kg_relationships_view_name,
        )
def _build_entity_explanation_str(entity_normalization_map: dict[str, str]) -> str:
"""
Build a string of contextualized entities to avoid the model not being aware of
what eg ACCOUNT::SF_8254Hs means as a normalized entity
"""
entity_explanation_components = []
for entity, normalized_entity in entity_normalization_map.items():
entity_explanation_components.append(f" - {entity} -> {normalized_entity}")
return "\n".join(entity_explanation_components)
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)
def _get_source_documents(
    sql_statement: str,
    llm: LLM,
    allowed_docs_view_name: str,
    kg_relationships_view_name: str,
) -> str | None:
    """
    Generate SQL to retrieve source documents based on the input sql statement.

    Returns the SQL found between the model's <sql>...</sql> markers, or None
    when the LLM call fails or the markers are missing.
    """
    source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
        "---original_sql_statement---", sql_statement
    )
    prompt_messages = [HumanMessage(content=source_detection_prompt)]

    cleaned_response: str | None = None
    try:
        llm_response = run_with_timeout(
            KG_SQL_GENERATION_TIMEOUT,
            llm.invoke,
            prompt=prompt_messages,
            timeout_override=25,
            max_tokens=1200,
        )
        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        # Keep only the text between the <sql> and </sql> markers; a missing
        # marker raises IndexError and is handled below.
        source_sql = cleaned_response.split("<sql>")[1].strip()
        source_sql = source_sql.split("</sql>")[0].strip()
    except Exception as e:
        if cleaned_response is not None:
            logger.error(
                f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
            )
        else:
            logger.error(f"Could not generate source documents SQL: {e}")
        return None

    return source_sql
def generate_simple_sql(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
    """
    LangGraph node to start the agentic search process.

    Three paths, chosen by the earlier query analysis:
      1) a single source document was already identified -> skip SQL entirely;
      2) a filtered search will be used -> skip SQL entirely;
      3) otherwise prompt the LLM to draft a SQL statement over the temporary
         KG views, run a correction pass, derive a companion statement for the
         supporting source documents, and execute both with the read-only user.
    """
    _KG_STEP_NR = 3  # UI step number for everything streamed from this node
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    entities_types_str = state.entities_types_str
    relationship_types_str = state.relationship_types_str
    single_doc_id = state.single_doc_id
    # The per-request temp views must have been created by an earlier node.
    if state.kg_doc_temp_view_name is None:
        raise ValueError("kg_doc_temp_view_name is not set")
    if state.kg_rel_temp_view_name is None:
        raise ValueError("kg_rel_temp_view_name is not set")
    ## STEP 3 - articulate goals
    stream_write_step_activities(writer, _KG_STEP_NR)
    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        # The prompt refers to the user as EMPLOYEE:<local part of the email>.
        user_email = graph_config.tooling.search_tool.user.email
        user_name = user_email.split("@")[0]
    if state.search_type == KGSearchType.SQL and single_doc_id:
        # If single doc id already identified, we do not need to go through the KG
        # query cycle, saving a lot of time.
        main_sql_statement: str | None = None
        query_results: list[dict[str, Any]] | None = None
        source_documents_sql: str | None = None
        source_document_results: list[str] | None = [single_doc_id]
        reasoning: str | None = (
            f"A KG query was not required as the source document was already identified: {single_doc_id}"
        )
        step_answer = f"Source document already identified: {single_doc_id}"
    elif state.search_type == KGSearchType.SEARCH:
        # If we do a filtered search, then we do not need to go through the SQL
        # generation process.
        main_sql_statement = None
        query_results = None
        source_documents_sql = None
        source_document_results = None
        reasoning = "A KG query was not required as we will use a filtered search."
        step_answer = "Filtered search will be used."
    else:
        # If no single doc id already identified, we need to go through the KG
        # query cycle, including generating the SQL for the answer and the sources
        # Build prompt
        # First, create string of contextualized entities to avoid the model not
        # being aware of what eg ACCOUNT::SF_8254Hs means as a normalized entity
        entity_explanation_str = _build_entity_explanation_str(
            state.entity_normalization_map
        )
        # Tenant-qualified names of the temp views created for this request.
        current_tenant = get_current_tenant_id()
        current_tenant_view_name = f'"{current_tenant}".{state.kg_doc_temp_view_name}'
        current_tenant_rel_view_name = f'"{current_tenant}".{state.kg_rel_temp_view_name}'
        # Fill the SQL-generation prompt template placeholders.
        simple_sql_prompt = (
            SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
            .replace("---relationship_types---", relationship_types_str)
            .replace("---question---", question)
            .replace("---entity_explanation_string---", entity_explanation_str)
            .replace(
                "---query_entities_with_attributes---",
                "\n".join(state.query_graph_entities_w_attributes),
            )
            .replace(
                "---query_relationships---", "\n".join(state.query_graph_relationships)
            )
            .replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
            .replace("---user_name---", f"EMPLOYEE:{user_name}")
        )
        # prepare SQL query generation
        msg = [
            HumanMessage(
                content=simple_sql_prompt,
            )
        ]
        primary_llm = graph_config.tooling.primary_llm
        # Grader
        try:
            llm_response = run_with_timeout(
                KG_SQL_GENERATION_TIMEOUT,
                primary_llm.invoke,
                prompt=msg,
                timeout_override=25,
                max_tokens=1500,
            )
            cleaned_response = (
                str(llm_response.content).replace("```json\n", "").replace("\n```", "")
            )
            # Extract the statement between <sql> tags; keep only the first
            # semicolon-terminated statement.
            sql_statement = (
                cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
            )
            sql_statement = sql_statement.split(";")[0].strip() + ";"
            # NOTE(review): this strips the literal substring "sql" anywhere in
            # the statement (presumably to drop a leftover code-fence tag); it
            # would also mangle identifiers containing "sql" — confirm.
            sql_statement = sql_statement.replace("sql", "").strip()
            # Point the statement at this tenant's relationship view.
            sql_statement = sql_statement.replace(
                "kg_relationship", current_tenant_rel_view_name
            )
            reasoning = (
                cleaned_response.split("<reasoning>")[1]
                .strip()
                .split("</reasoning>")[0]
            )
        except Exception as e:
            logger.error(f"Error in strategy generation: {e}")
            # Clean up the temp views before propagating the failure.
            _drop_temp_views(
                allowed_docs_view_name=current_tenant_view_name,
                kg_relationships_view_name=current_tenant_rel_view_name,
            )
            raise e
        logger.debug(f"A3 - sql_statement: {sql_statement}")
        # Correction if needed:
        correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
            "---draft_sql---", sql_statement
        )
        msg = [
            HumanMessage(
                content=correction_prompt,
            )
        ]
        try:
            llm_response = run_with_timeout(
                KG_SQL_GENERATION_TIMEOUT,
                primary_llm.invoke,
                prompt=msg,
                timeout_override=25,
                max_tokens=1500,
            )
            cleaned_response = (
                str(llm_response.content).replace("```json\n", "").replace("\n```", "")
            )
            sql_statement = (
                cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
            )
            sql_statement = sql_statement.split(";")[0].strip() + ";"
            sql_statement = sql_statement.replace("sql", "").strip()
        except Exception as e:
            logger.error(
                f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
            )
            _drop_temp_views(
                allowed_docs_view_name=current_tenant_view_name,
                kg_relationships_view_name=current_tenant_rel_view_name,
            )
            raise e
        logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
        # Get SQL for source documents
        source_documents_sql = _get_source_documents(
            sql_statement,
            llm=primary_llm,
            allowed_docs_view_name=current_tenant_view_name,
            kg_relationships_view_name=current_tenant_rel_view_name,
        )
        logger.info(f"A3 source_documents_sql: {source_documents_sql}")
        scalar_result = None
        query_results = None
        # Execute the generated SQL with the read-only DB user.
        with get_db_readonly_user_session_with_current_tenant() as db_session:
            try:
                result = db_session.execute(text(sql_statement))
                # Handle scalar results (like COUNT)
                if sql_statement.upper().startswith("SELECT COUNT"):
                    scalar_result = result.scalar()
                    query_results = (
                        [{"count": int(scalar_result)}]
                        if scalar_result is not None
                        else []
                    )
                else:
                    # Handle regular row results
                    rows = result.fetchall()
                    query_results = [dict(row._mapping) for row in rows]
            except Exception as e:
                logger.error(f"Error executing SQL query: {e}")
                # NOTE(review): unlike the LLM failure paths above, this does
                # not drop the temp views before re-raising — confirm intended.
                raise e
        source_document_results = None
        if source_documents_sql is not None and source_documents_sql != sql_statement:
            with get_db_readonly_user_session_with_current_tenant() as db_session:
                try:
                    result = db_session.execute(text(source_documents_sql))
                    rows = result.fetchall()
                    query_source_document_results = [dict(row._mapping) for row in rows]
                    source_document_results = [
                        source_document_result["source_document"]
                        for source_document_result in query_source_document_results
                    ]
                except Exception as e:
                    # No stopping here, the individualized SQL query is not mandatory
                    logger.error(f"Error executing Individualized SQL query: {e}")
                    # individualized_query_results = None
        else:
            source_document_results = None
        # NOTE(review): this drops the views by their unqualified names while
        # the error paths above pass the tenant-qualified names — confirm both
        # spellings resolve to the same views.
        _drop_temp_views(
            allowed_docs_view_name=state.kg_doc_temp_view_name,
            kg_relationships_view_name=state.kg_rel_temp_view_name,
        )
        logger.info(f"A3 - Number of query_results: {len(query_results)}")
        # Stream out reasoning and SQL query
        step_answer = f"Reasoning: {reasoning} \n \n SQL Query: {sql_statement}"
        main_sql_statement = sql_statement
    if reasoning:
        stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
    if main_sql_statement:
        stream_write_step_answer_explicit(
            writer,
            step_nr=_KG_STEP_NR,
            answer=f" \n Generated SQL: {main_sql_statement}",
        )
    stream_close_step_answer(writer, _KG_STEP_NR)
    # Update path if too many results are retrieved
    if query_results and len(query_results) > KG_MAX_DEEP_SEARCH_RESULTS:
        updated_strategy = KGAnswerStrategy.SIMPLE
    else:
        updated_strategy = None
    return SQLSimpleGenerationUpdate(
        sql_query=main_sql_statement,
        sql_query_results=query_results,
        individualized_sql_query=None,
        individualized_query_results=None,
        source_documents_sql=source_documents_sql,
        source_document_results=source_document_results or [],
        updated_strategy=updated_strategy,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="generate simple sql",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -0,0 +1,186 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.kg_configs import KG_FILTER_CONSTRUCTION_TIMEOUT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def construct_deep_search_filters(
    state: MainState, config: RunnableConfig, writer: StreamWriter
) -> DeepSearchFilterUpdate:
    """
    LangGraph node that constructs the filters for the deep-search phase.

    Prompts the LLM with the entity/relationship context plus the earlier SQL
    and source-document results to produce entity, relationship, and
    source-document filters and a divide-and-conquer structure. Any failure
    (timeout, malformed response) degrades to empty filters rather than
    aborting the graph.

    Bug fix: the character-count log line previously read
    ``f"... len{source_document_results_str}"`` which logged the literal text
    "len" followed by the entire string; it now logs the actual length.
    """
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities_no_attributes
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query
    simple_sql_results = state.sql_query_results
    source_document_results = state.source_document_results
    if simple_sql_results:
        simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
    else:
        simple_sql_results_str = "(no SQL results generated)"
    if source_document_results:
        source_document_results_str = "\n".join(
            [str(x) for x in source_document_results]
        )
    else:
        source_document_results_str = "(no source document results generated)"
    logger.info(
        "B1 - characters in source_document_results_str: "
        f"{len(source_document_results_str)}"
    )
    # Fill the filter-construction prompt template placeholders.
    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---sql_results---",
            simple_sql_results_str or "(no SQL results generated)",
        )
        .replace(
            "---source_document_results---",
            source_document_results_str or "(no source document results generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )
    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]
    llm = graph_config.tooling.primary_llm
    try:
        llm_response = run_with_timeout(
            KG_FILTER_CONSTRUCTION_TIMEOUT,
            llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=1400,
        )
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        # Keep only the outermost JSON object in the response.
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        if last_bracket == -1 or first_bracket == -1:
            raise ValueError("No valid JSON found in LLM response - no brackets found")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        # Repair doubled braces that some models emit around the JSON object.
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')
        try:
            filter_results = KGFilterConstructionResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            filter_results = KGFilterConstructionResults(
                global_entity_filters=[],
                global_relationship_filters=[],
                local_entity_filters=[],
                source_document_filters=[],
                structure=[],
            )
        logger.info(f"B1 - filter_results: {filter_results}")
    except Exception as e:
        logger.error(f"Error in construct_deep_search_filters: {e}")
        filter_results = KGFilterConstructionResults(
            global_entity_filters=[],
            global_relationship_filters=[],
            local_entity_filters=[],
            source_document_filters=[],
            structure=[],
        )
    div_con_structure = filter_results.structure
    logger.info(f"div_con_structure: {div_con_structure}")
    with get_session_with_current_tenant() as db_session:
        double_grounded_entity_types = get_entity_types_with_grounded_source_name(
            db_session
        )
    # Source-division mode applies when the first structure element names a
    # grounded source entity type.
    source_division = False
    if div_con_structure:
        for entity_type in double_grounded_entity_types:
            if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
                source_division = True
                break
    return DeepSearchFilterUpdate(
        vespa_filter_results=filter_results,
        div_con_entities=div_con_structure,
        source_division=source_division,
        global_entity_filters=[
            make_entity_id(global_filter, "*")
            for global_filter in filter_results.global_entity_filters
        ],
        global_relationship_filters=filter_results.global_relationship_filters,
        local_entity_filters=filter_results.local_entity_filters,
        source_filters=filter_results.source_document_filters,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -0,0 +1,178 @@
import copy
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import (
get_doc_information_for_entity,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.chat.models import SubQueryPiece
from onyx.configs.kg_configs import KG_OBJECT_SOURCE_RESEARCH_TIMEOUT
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def process_individual_deep_search(
    state: ResearchObjectInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
    """
    LangGraph node that researches one KG object/entity.

    Runs a KG-filtered retrieval for the broken-down question restricted to
    this object's entity filter, then asks the LLM to summarize the retrieved
    documents. Fans in via ResearchObjectUpdate.research_object_results.
    """
    _KG_STEP_NR = 4  # UI step number for the streamed research subquery
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = state.broken_down_question
    source_division = state.source_division
    source_filters = state.source_entity_filters
    # NOTE(review): `object` shadows the builtin — consider renaming.
    # "TYPE::ID" -> "type:: id" (lower-cased, space inserted after "::").
    object = state.entity.replace("::", ":: ").lower()
    object_id = object.split("::")[1].strip()
    if not search_tool:
        raise ValueError("search_tool is not provided")
    research_nr = state.research_nr
    extended_question = f"{question} in regards to {object}"
    if source_division:
        # In source-division mode the question is not specialized per object.
        extended_question = question
    # Global entity filters plus this object's own entity, de-duplicated.
    raw_kg_entity_filters = copy.deepcopy(
        list(set((state.vespa_filter_results.global_entity_filters + [state.entity])))
    )
    # Filters without an explicit id are widened to match all ids of the type.
    kg_entity_filters = []
    for raw_kg_entity_filter in raw_kg_entity_filters:
        if "::" not in raw_kg_entity_filter:
            raw_kg_entity_filter += "::*"
        kg_entity_filters.append(raw_kg_entity_filter)
    kg_relationship_filters = copy.deepcopy(
        state.vespa_filter_results.global_relationship_filters
    )
    # if this is a single-object analysis and object is a source,
    # drop the other filters as they already were included
    # in the KG query that led to this analysis
    if source_division and source_filters:
        for source_filter in source_filters:
            if object_id.lower() == source_filter.lower():
                source_filters = [source_filter]
                kg_relationship_filters = []
                kg_entity_filters = []
                break
    logger.info("Research for object: " + object)
    logger.info(f"kg_entity_filters: {kg_entity_filters}")
    logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
    # Step 4 - stream out the research query
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=research_nr + 1,
        ),
        writer,
    )
    # Retrieve documents constrained by the KG filters.
    retrieved_docs = research(
        question=extended_question,
        kg_entities=kg_entity_filters,
        kg_relationships=kg_relationship_filters,
        kg_sources=source_filters,
        search_tool=search_tool,
    )
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)
    # Built prompt
    kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
        question=extended_question,
        document_text=document_texts,
    )
    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=KG_OBJECT_SOURCE_RESEARCH_TIMEOUT,
            max_tokens=300,
        )
        object_research_results = str(llm_response.content).replace("```json\n", "")
    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")
    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
    return ResearchObjectUpdate(
        research_object_results=[
            {
                # NOTE(review): `object` already had "::" replaced above, so
                # this second replace is a no-op — confirm before simplifying.
                "object": object.replace("::", ":: ").capitalize(),
                "results": object_research_results,
            }
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="process individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -0,0 +1,177 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import SubQueryPiece
from onyx.configs.kg_configs import KG_FILTERED_SEARCH_TIMEOUT
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def filtered_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """
    LangGraph node to do a filtered search.

    Retrieves documents for the raw user question constrained by the KG
    entity/relationship filters, then asks the LLM to answer from them.

    Cleanup: removed a leftover ``datetime.now().strftime(...)`` statement
    whose result was discarded, and the redundant ``llm = primary_llm`` alias.
    """
    _KG_STEP_NR = 4  # UI step number for the streamed research subquery
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.prompt_builder.raw_user_query
    if not search_tool:
        raise ValueError("search_tool is not provided")
    if not state.vespa_filter_results:
        raise ValueError("vespa_filter_results is not provided")
    raw_kg_entity_filters = list(
        set((state.vespa_filter_results.global_entity_filters))
    )
    # Filters without an explicit id are widened to match all ids of the type.
    kg_entity_filters = []
    for raw_kg_entity_filter in raw_kg_entity_filters:
        if "::" not in raw_kg_entity_filter:
            raw_kg_entity_filter += "::*"
        kg_entity_filters.append(raw_kg_entity_filter)
    kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
    logger.info("Starting filtered search")
    logger.debug(f"kg_entity_filters: {kg_entity_filters}")
    logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
    # Step 4 - stream out the research query
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query="Conduct a filtered search",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=1,
        ),
        writer,
    )
    retrieved_docs = cast(
        list[InferenceSection],
        research(
            question=question,
            kg_entities=kg_entity_filters,
            kg_relationships=kg_relationship_filters,
            kg_sources=None,
            search_tool=search_tool,
            inference_sections_only=True,
        ),
    )
    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=retrieved_docs,
        context_documents=retrieved_docs,
        original_question_docs=retrieved_docs,
        max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
    )
    # Render the retrieved sections into one prompt-ready context string.
    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(
        answer_generation_documents.context_documents
    ):
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)
    document_texts = "\n\n".join(document_texts_list)
    # Build prompt
    kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
        question=question,
        document_text=document_texts,
    )
    # Run LLM
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=graph_config.tooling.primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    try:
        llm_response = run_with_timeout(
            KG_FILTERED_SEARCH_TIMEOUT,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
        filtered_search_answer = str(llm_response.content).replace("```json\n", "")
    except Exception as e:
        raise ValueError(f"Error in filtered_search: {e}")
    step_answer = "Filtered search is complete."
    stream_write_step_answer_explicit(
        writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
    )
    stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=filtered_search_answer,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="filtered search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=retrieved_docs,
            )
        ],
    )

View File

@@ -0,0 +1,64 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def consolidate_individual_deep_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """
    Merge the per-object research results into one consolidated string and
    stream the step-completion message.
    """
    _KG_STEP_NR = 4
    node_start_time = datetime.now()
    # One "<object>: <results>" line per researched object.
    merged_results = "\n".join(
        f"{entry['object']}: {entry['results']}"
        for entry in state.research_object_results
    )
    consolidated_research_object_results_str = rename_entities_in_answer(
        merged_results
    )
    step_answer = "All research is complete. Consolidating results..."
    stream_write_step_answer_explicit(
        writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
    )
    stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=consolidated_research_object_results_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="consolidate individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -0,0 +1,122 @@
from datetime import datetime
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_formated_source_reference_results(
source_document_results: list[str] | None,
) -> str | None:
"""
Generate reference results from the query results data string.
"""
if source_document_results is None:
return None
# get all entities that correspond to an Onyx document
document_ids = source_document_results
with get_session_with_current_tenant() as session:
llm_doc_information_results = get_base_llm_doc_information(
session, document_ids
)
if len(llm_doc_information_results) == 0:
return ""
return (
f"\n \n Here are {len(llm_doc_information_results)} supporting documents or examples: \n \n "
+ " \n \n ".join(llm_doc_information_results)
)
def process_kg_only_answers(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
    """
    Format the KG SQL query results and their supporting source references —
    no further retrieval is needed, the knowledge graph answered directly.
    """
    _KG_STEP_NR = 4
    node_start_time = datetime.now()
    query_results = state.sql_query_results
    source_document_results = state.source_document_results
    # we use this stream write explicitly
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query="Formatted References",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=1,
        ),
        writer,
    )
    if not query_results:
        raise ValueError("No query results were found")
    # One display line per result row, with "::" spaced out for readability.
    query_results_data_str = "\n".join(
        str(row).replace("::", ":: ").capitalize() for row in query_results
    )
    source_reference_result_str = _get_formated_source_reference_results(
        source_document_results
    )
    ## STEP 4 - same components as Step 1
    step_answer = (
        "No further research is needed, the answer is derived from the knowledge graph."
    )
    stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
    stream_close_step_answer(writer, _KG_STEP_NR)
    return ResultsDataUpdate(
        query_results_data_str=query_results_data_str,
        individualized_query_results_data_str="",
        reference_results_str=source_reference_result_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="kg query results data processing",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -0,0 +1,282 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_close_main_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_main_answer_token,
)
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.kg_configs import KG_ANSWER_GENERATION_TIMEOUT
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.enums import SearchType
from onyx.context.search.models import InferenceSection
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _stream_augmentations(
llm_tokenizer: BaseTokenizer, streaming_text: str, writer: StreamWriter
) -> None:
# Tokenize and stream the reference results
tokens = llm_tokenizer.tokenize(streaming_text)
for token in tokens:
stream_write_main_answer_token(writer, token)
def generate_answer(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
    """
    LangGraph node that produces and streams the final KG answer.

    Closes out the step-progress stream, retrieves source documents if the
    earlier steps did not, streams the retrieved sections to the UI as a tool
    response, builds the output-format prompt, and streams the LLM answer
    token by token (bounded by KG_ANSWER_GENERATION_TIMEOUT).
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.prompt_builder.raw_user_query
    search_tool = graph_config.tooling.search_tool
    if search_tool is None:
        raise ValueError("Search tool is not set")

    # Close out previous streams of steps
    # DECLARE STEPS DONE
    stream_write_close_steps(writer)

    ## MAIN ANSWER

    # identify whether documents have already been retrieved
    retrieved_docs: list[InferenceSection] = []
    for step_result in state.step_results:
        retrieved_docs += step_result.verified_reranked_documents

    # if still needed, get a search done and send the results to the UI
    if not retrieved_docs and state.source_document_results:
        assert graph_config.tooling.search_tool is not None
        retrieved_docs = cast(
            list[InferenceSection],
            research(
                question=question,
                kg_entities=[],
                kg_relationships=[],
                kg_sources=state.source_document_results[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ],
                search_tool=graph_config.tooling.search_tool,
                kg_chunk_id_zero_only=True,
                inference_sections_only=True,
            ),
        )

    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=retrieved_docs,
        context_documents=retrieved_docs,
        original_question_docs=retrieved_docs,
        max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
    )

    relevance_list = relevance_from_docs(
        answer_generation_documents.streaming_documents
    )
    assert graph_config.tooling.search_tool is not None
    # Stream the retrieved sections out as a tool response for the UI.
    for tool_response in yield_search_responses(
        query=question,
        get_retrieved_sections=lambda: answer_generation_documents.context_documents,
        get_final_context_sections=lambda: answer_generation_documents.context_documents,
        search_query_info=SearchQueryInfo(
            predicted_search=SearchType.KEYWORD,
            # acl here is empty, because the search already happened and
            # we are streaming out the results.
            final_filters=IndexFilters(access_control_list=[]),
            recency_bias_multiplier=1.0,
        ),
        get_section_relevance=lambda: relevance_list,
        search_tool=graph_config.tooling.search_tool,
    ):
        write_custom_event(
            "tool_response",
            ExtendedToolResponse(
                id=tool_response.id,
                response=tool_response.response,
                level=0,
                level_question_num=0,  # 0, 0 is the base question
            ),
            writer,
        )

    # continue with the answer generation
    output_format = (
        state.output_format.value
        if state.output_format
        else "<you be the judge how to best present the data>"
    )

    # if deep path was taken:
    consolidated_research_object_results_str = (
        state.consolidated_research_object_results_str
    )
    # if simple path was taken:
    introductory_answer = state.query_results_data_str  # from simple answer path only

    # NOTE(review): the original branch chain had two additional `elif`
    # branches keyed on a `research_results` variable that was only ever
    # assigned from `consolidated_research_object_results_str`, which made
    # those branches unreachable; they were removed without changing behavior.
    if introductory_answer:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace(
                "---introductory_answer---",
                rename_entities_in_answer(introductory_answer),
            )
            .replace("---output_format---", str(output_format) if output_format else "")
        )
    elif consolidated_research_object_results_str:
        output_format_prompt = (
            OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
            .replace(
                "---introductory_answer---",
                rename_entities_in_answer(consolidated_research_object_results_str),
            )
            .replace("---output_format---", str(output_format) if output_format else "")
        )
    else:
        raise ValueError("No research results or introductory answer provided")

    msg = [
        HumanMessage(
            content=output_format_prompt,
        )
    ]
    fast_llm = graph_config.tooling.fast_llm
    dispatch_timings: list[float] = []
    response: list[str] = []

    def stream_answer() -> list[str]:
        """Stream the LLM answer token by token, collecting tokens in `response`."""
        # Get the LLM's tokenizer
        llm_tokenizer = get_tokenizer(
            model_name=fast_llm.config.model_name,
            provider_type=fast_llm.config.model_provider,
        )
        for message in fast_llm.stream(
            prompt=msg,
            timeout_override=30,
            max_tokens=1000,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            # Tokenize the content using the LLM's tokenizer
            tokens = llm_tokenizer.tokenize(content)
            for token in tokens:
                start_stream_token = datetime.now()
                stream_write_main_answer_token(
                    writer, token, level=0, level_question_num=0
                )
                end_stream_token = datetime.now()
                # NOTE(review): `.microseconds` only holds the sub-second
                # component of the delta; timings are currently collected
                # for debugging and never read back.
                dispatch_timings.append(
                    (end_stream_token - start_stream_token).microseconds
                )
                response.append(token)
        return response

    try:
        response = run_with_timeout(
            KG_ANSWER_GENERATION_TIMEOUT,
            stream_answer,
        )
    except Exception as e:
        # chain the cause so the original streaming failure stays visible
        raise ValueError(f"Could not generate the answer. Error {e}") from e

    stream_write_close_main_answer(writer)

    return MainOutput(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="query completed",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,58 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_sub_question_results
from onyx.utils.logger import setup_logger
# Module-level logger for this KG graph node module.
logger = setup_logger()
def log_data(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
    """
    LangGraph node that persists the sub-question results of this KG run.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")

    # commit original db_session
    session = graph_config.persistence.db_session
    session.commit()

    log_agent_sub_question_results(
        db_session=session,
        chat_session_id=graph_config.persistence.chat_session_id,
        primary_message_id=graph_config.persistence.message_id,
        sub_question_answer_results=state.step_results,
    )

    return MainOutput(
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="query completed",
                node_start_time=node_start_time,
            )
        ],
    )

View File

@@ -0,0 +1,65 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
    kg_entities: list[str] | None = None,
    kg_relationships: list[str] | None = None,
    kg_terms: list[str] | None = None,
    kg_sources: list[str] | None = None,
    kg_chunk_id_zero_only: bool = False,
    inference_sections_only: bool = False,
) -> list[LlmDoc] | list[InferenceSection]:
    """Run the search tool for `question` with the given KG filters.

    Returns at most KG_RESEARCH_NUM_RETRIEVED_DOCS results: the raw retrieved
    InferenceSections when `inference_sections_only` is set, otherwise the
    final LlmDoc context documents.
    """
    sections_callback: list[list[InferenceSection]] = []
    results: list[LlmDoc] | list[InferenceSection] = []

    # new db session to avoid concurrency issues
    with get_session_with_current_tenant() as db_session:
        override_kwargs = SearchToolOverrideKwargs(
            force_no_rerank=False,
            alternate_db_session=db_session,
            retrieved_sections_callback=sections_callback.append,
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
            kg_entities=kg_entities,
            kg_relationships=kg_relationships,
            kg_terms=kg_terms,
            kg_sources=kg_sources,
            kg_chunk_id_zero_only=kg_chunk_id_zero_only,
        )
        for tool_response in search_tool.run(
            query=question,
            override_kwargs=override_kwargs,
        ):
            if (
                inference_sections_only
                and tool_response.id == "search_response_summary"
            ):
                results = cast(
                    list[InferenceSection],
                    tool_response.response.top_sections[
                        :KG_RESEARCH_NUM_RETRIEVED_DOCS
                    ],
                )
                break
            # get retrieved docs to send to the rest of the graph
            elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
                results = cast(list[LlmDoc], tool_response.response)[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ]
                break

    return results

View File

@@ -0,0 +1,169 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
from onyx.context.search.models import InferenceSection
### States ###
class StepResults(BaseModel):
    """Results for a single answered sub-question (step) in the KG graph."""

    question: str
    question_id: str
    answer: str
    sub_query_retrieval_results: list[QueryRetrievalResult]
    verified_reranked_documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    cited_documents: list[InferenceSection]
class LoggerUpdate(BaseModel):
    """Base state update; `add`-annotated lists are merged by concatenation
    when parallel LangGraph nodes contribute updates."""

    log_messages: Annotated[list[str], add] = []
    step_results: Annotated[list[SubQuestionAnswerResults], add]
    remarks: Annotated[list[str], add] = []
class KGFilterConstructionResults(BaseModel):
    """Vespa filter sets constructed for the deep (divide-and-conquer) search."""

    global_entity_filters: list[str]
    global_relationship_filters: list[str]
    local_entity_filters: list[list[str]]
    source_document_filters: list[str]
    structure: list[str]
class KGSearchType(Enum):
    """How the KG answer is obtained: vector/keyword search or a SQL query."""

    SEARCH = "SEARCH"
    SQL = "SQL"
class KGAnswerStrategy(Enum):
    """Whether to take the deep (per-object research) or simple answer path."""

    DEEP = "DEEP"
    SIMPLE = "SIMPLE"
class KGAnswerFormat(Enum):
    """Requested presentation format of the final answer."""

    LIST = "LIST"
    TEXT = "TEXT"
class YesNoEnum(str, Enum):
    """String-valued yes/no flag (str subclass for direct comparison/serialization)."""

    YES = "yes"
    NO = "no"
class AnalysisUpdate(LoggerUpdate):
    """State update from the question-analysis node: normalized entities,
    relationships and terms, plus the chosen strategy and output format."""

    normalized_core_entities: list[str] = []
    normalized_core_relationships: list[str] = []
    # maps raw extracted names -> normalized KG names
    entity_normalization_map: dict[str, str] = {}
    relationship_normalization_map: dict[str, str] = {}
    query_graph_entities_no_attributes: list[str] = []
    query_graph_entities_w_attributes: list[str] = []
    query_graph_relationships: list[str] = []
    normalized_terms: list[str] = []
    normalized_time_filter: str | None = None
    strategy: KGAnswerStrategy | None = None
    output_format: KGAnswerFormat | None = None
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None
    # set when the question targets exactly one source document
    single_doc_id: str | None = None
    search_type: KGSearchType | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
    """State update from the SQL-generation node: generated queries and
    their result rows, plus a possibly revised answer strategy."""

    sql_query: str | None = None
    sql_query_results: list[Dict[Any, Any]] | None = None
    individualized_sql_query: str | None = None
    individualized_query_results: list[Dict[Any, Any]] | None = None
    source_documents_sql: str | None = None
    source_document_results: list[str] | None = None
    updated_strategy: KGAnswerStrategy | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
    """State update holding the consolidated deep-research results string."""

    consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
    """State update from the deep-search filter-construction node."""

    vespa_filter_results: KGFilterConstructionResults | None = None
    div_con_entities: list[str] | None = None
    source_division: bool | None = None
    global_entity_filters: list[str] | None = None
    global_relationship_filters: list[str] | None = None
    local_entity_filters: list[list[str]] | None = None
    source_filters: list[str] | None = None
class ResearchObjectOutput(LoggerUpdate):
    """Accumulated per-object research results (merged across parallel branches)."""

    research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
    """State update from entity/relationship/term (ERT) extraction."""

    entities_types_str: str = ""
    relationship_types_str: str = ""
    extracted_entities_w_attributes: list[str] = []
    extracted_entities_no_attributes: list[str] = []
    extracted_relationships: list[str] = []
    extracted_terms: list[str] = []
    time_filter: str | None = None
    # names of temporary DB views created for this query, if any
    kg_doc_temp_view_name: str | None = None
    kg_rel_temp_view_name: str | None = None
class ResultsDataUpdate(LoggerUpdate):
    """State update carrying stringified query results for answer generation."""

    query_results_data_str: str | None = None
    individualized_query_results_data_str: str | None = None
    reference_results_str: str | None = None
class ResearchObjectUpdate(LoggerUpdate):
    """Per-object research results contributed by a research branch."""

    research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
    """Input state of the KG graph; adds nothing beyond the core state."""

    pass
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    ERTExtractionUpdate,
    AnalysisUpdate,
    SQLSimpleGenerationUpdate,
    ResultsDataUpdate,
    ResearchObjectOutput,
    DeepSearchFilterUpdate,
    ResearchObjectUpdate,
    ConsolidatedResearchUpdate,
):
    """Full KG graph state: the union of all per-node update states."""

    pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
    """Final graph output: the accumulated log messages."""

    log_messages: list[str]
class ResearchObjectInput(LoggerUpdate):
    """Input for a single parallel research branch over one entity."""

    # 1-based index of this research branch
    research_nr: int
    entity: str
    broken_down_question: str
    vespa_filter_results: KGFilterConstructionResults
    source_division: bool | None
    source_entity_filters: list[str] | None

View File

@@ -0,0 +1,29 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
# Static, user-visible descriptions of the KG answering steps keyed by step
# number; `activities` lists the sub-items streamed under each step.
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(
        description="Analyzing the question...",
        activities=[
            "Entities in Query",
            "Relationships in Query",
            "Terms in Query",
            "Time Filters",
        ],
    ),
    2: KGSteps(
        description="Planning the response approach...",
        activities=["Query Execution Strategy", "Answer Format"],
    ),
    3: KGSteps(
        description="Querying the Knowledge Graph...",
        activities=[
            "Knowledge Graph Query",
            "Knowledge Graph Query Results",
            "Query for Source Documents",
            "Source Documents",
        ],
    ),
    4: KGSteps(
        description="Conducting further research on source documents...", activities=[]
    ),
}

View File

@@ -8,6 +8,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.kg.models import KGConfigSettings
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
@@ -70,6 +71,7 @@ class GraphSearchConfig(BaseModel):
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
kg_config_settings: KGConfigSettings = KGConfigSettings()
class GraphConfig(BaseModel):

View File

@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
@@ -88,7 +90,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput,
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -102,7 +104,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput,
input: BasicInput | MainInput | DCMainInput | KBMainInput,
) -> AnswerStream:
for event in manage_sync_streaming(
@@ -147,6 +149,21 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_kb_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Compile the knowledge-graph answer graph and stream its events."""
    compiled_graph = kb_graph_builder().compile()
    kb_input = KBMainInput(log_messages=[])
    # Announce the (pseudo) tool call before streaming the graph output.
    yield ToolCallKickoff(
        tool_name="agent_search_0",
        tool_args={"query": config.inputs.prompt_builder.raw_user_query},
    )
    yield from run_graph(compiled_graph, config, kb_input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -159,3 +159,8 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
class ReferenceResults(BaseModel):
    """Citations and general entities referenced by a KG answer."""

    citations: list[str]
    general_entities: list[str]

View File

@@ -0,0 +1,105 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()

# Celery app for the KG processing worker, configured from its dedicated
# config module; all tasks run as TenantAwareTask.
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Delegate task-prerun handling to the shared app_base handler."""
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Delegate task-postrun handling to the shared app_base handler."""
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    """Delegate daemon init to the shared app_base handler."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Initialize the DB engine and wait for dependencies before startup."""
    logger.info("worker_init signal received.")
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
    # Size the SQL connection pool to the worker's concurrency.
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    # Less startup checks in multi-tenant case
    if MULTI_TENANT:
        return
    app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Delegate worker-ready handling to the shared app_base handler."""
    app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Delegate worker-shutdown handling to the shared app_base handler."""
    app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
    """Reset the SQL engine in each forked worker process (no shared connections)."""
    SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Delegate logging setup to the shared app_base handler."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Register the KG processing tasks with this worker's app.
celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.kg_processing",
    ]
)

View File

@@ -298,5 +298,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.kg_processing",
]
)

View File

@@ -0,0 +1,21 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_KG_PROCESSING_CONCURRENCY
# Celery settings for the KG processing worker. Broker/redis/result settings
# are shared with the base config; only concurrency and pool are specific.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# KG processing runs in a thread pool, fetching one task per thread at a time.
worker_concurrency = CELERY_WORKER_KG_PROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -28,6 +28,24 @@ beat_task_templates: list[dict] = []
beat_task_templates.extend(
[
{
"name": "check-for-kg-processing",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
"schedule": timedelta(seconds=60),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-kg-processing-clustering-only",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
"schedule": timedelta(seconds=600),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,

View File

@@ -0,0 +1,321 @@
import time
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import check_for_documents_needing_kg_clustering
from onyx.db.document import check_for_documents_needing_kg_processing
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import get_kg_processing_in_progress_status
from onyx.db.kg_config import KGProcessingType
from onyx.db.kg_config import set_kg_processing_in_progress_status
from onyx.db.search_settings import get_current_search_settings
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
# Module-level logger (task_logger is used inside tasks for task context).
logger = setup_logger()
@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing(self: Task, *, tenant_id: str) -> int | None:
    """Beat task that checks whether KG extraction + clustering should run.

    If KG is enabled, no KG processing is currently in progress, and there
    are documents needing processing, enqueues a KG_PROCESSING task.
    Returns the number of tasks created, or None when skipped/locked out.
    """
    time_start = time.monotonic()
    task_logger.warning("check_for_kg_processing - Starting")

    tasks_created = 0
    locked = False
    redis_client = get_redis_client()
    # NOTE(review): replica client is created but unused; kept for parity
    # with similar beat tasks -- confirm whether it can be dropped.
    get_redis_replica_client()

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        locked = True

        with get_session_with_current_tenant() as db_session:
            kg_config = get_kg_config_settings(db_session)

            if not kg_config.KG_ENABLED:
                return None

            kg_coverage_start = kg_config.KG_COVERAGE_START
            kg_max_coverage_days = kg_config.KG_MAX_COVERAGE_DAYS
            kg_extraction_in_progress = kg_config.KG_EXTRACTION_IN_PROGRESS
            kg_clustering_in_progress = kg_config.KG_CLUSTERING_IN_PROGRESS

        if kg_extraction_in_progress or kg_clustering_in_progress:
            task_logger.info(
                f"KG processing already in progress for tenant {tenant_id}, skipping"
            )
            return None

        with get_session_with_current_tenant() as db_session:
            documents_needing_kg_processing = check_for_documents_needing_kg_processing(
                db_session, kg_coverage_start, kg_max_coverage_days
            )

        if not documents_needing_kg_processing:
            task_logger.info(
                f"No documents needing KG processing for tenant {tenant_id}, skipping"
            )
            return None

        task_logger.info(
            f"Found documents needing KG processing for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_PROCESSING,
            kwargs={
                "tenant_id": tenant_id,
            },
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )
        # track what we actually enqueued so the return value is meaningful
        tasks_created += 1
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg processing check")
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                task_logger.error(
                    "check_for_kg_processing - Lock not owned on completion: "
                    f"tenant={tenant_id}"
                )
                redis_lock_dump(lock_beat, redis_client)

    time_elapsed = time.monotonic() - time_start
    # was previously mislabeled "check_for_indexing finished"
    task_logger.info(f"check_for_kg_processing finished: elapsed={time_elapsed:.2f}")
    return tasks_created
@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing_clustering_only(
    self: Task, *, tenant_id: str
) -> int | None:
    """Beat task that checks whether a clustering-only KG run should start.

    If no clustering is in progress and there are documents needing
    clustering, enqueues a KG_CLUSTERING_ONLY task. Returns the number of
    tasks created, or None when skipped/locked out.
    """
    time_start = time.monotonic()
    task_logger.warning("check_for_kg_processing_clustering_only - Starting")

    tasks_created = 0
    locked = False
    redis_client = get_redis_client()
    # NOTE(review): replica client is created but unused; kept for parity
    # with similar beat tasks -- confirm whether it can be dropped.
    get_redis_replica_client()

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        locked = True

        with get_session_with_current_tenant() as db_session:
            kg_clustering_in_progress = get_kg_processing_in_progress_status(
                db_session, processing_type=KGProcessingType.CLUSTERING
            )
            documents_needing_kg_clustering = check_for_documents_needing_kg_clustering(
                db_session
            )

        if kg_clustering_in_progress:
            task_logger.info(
                f"KG clustering already in progress for tenant {tenant_id}, skipping"
            )
            return None
        elif not documents_needing_kg_clustering:
            task_logger.info(
                f"No documents needing KG clustering for tenant {tenant_id}, skipping"
            )
            return None

        task_logger.info(
            f"Found documents needing KG processing for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_CLUSTERING_ONLY,
            kwargs={
                "tenant_id": tenant_id,
            },
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )
        # track what we actually enqueued so the return value is meaningful
        tasks_created += 1
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg clustering-only check")
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                # was previously mislabeled "check_for_kg_processing"
                task_logger.error(
                    "check_for_kg_processing_clustering_only - Lock not owned on completion: "
                    f"tenant={tenant_id}"
                )
                redis_lock_dump(lock_beat, redis_client)

    time_elapsed = time.monotonic() - time_start
    # was previously mislabeled "check_for_indexing finished"
    task_logger.info(
        f"check_for_kg_processing_clustering_only finished: elapsed={time_elapsed:.2f}"
    )
    return tasks_created
@shared_task(
    name=OnyxCeleryTask.KG_PROCESSING,
    soft_time_limit=1000,
    bind=True,
)
def kg_processing(self: Task, *, tenant_id: str) -> int | None:
    """Run KG extraction followed by KG clustering for the tenant.

    Marks both stages as in-progress up front and always clears each stage's
    flag when that stage finishes, even on failure.
    """
    # was previously mislabeled "check_for_kg_processing - Starting"
    task_logger.warning(f"kg_processing - Starting for tenant {tenant_id}")

    with get_session_with_current_tenant() as db_session:
        search_settings = get_current_search_settings(db_session)
        index_str = search_settings.index_name
        set_kg_processing_in_progress_status(
            db_session, processing_type=KGProcessingType.EXTRACTION, in_progress=True
        )
        set_kg_processing_in_progress_status(
            db_session, processing_type=KGProcessingType.CLUSTERING, in_progress=True
        )
        db_session.commit()

    task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")

    try:
        kg_extraction(
            tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
        )
    except Exception as e:
        task_logger.exception(f"Error during kg extraction: {e}")
    finally:
        # always clear the extraction flag so future runs are not blocked
        with get_session_with_current_tenant() as db_session:
            set_kg_processing_in_progress_status(
                db_session,
                processing_type=KGProcessingType.EXTRACTION,
                in_progress=False,
            )
            db_session.commit()

    task_logger.debug("Completed kg extraction task. Moving to clustering")

    try:
        kg_clustering(
            tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
        )
    except Exception as e:
        task_logger.exception(f"Error during kg clustering: {e}")
    finally:
        # always clear the clustering flag so future runs are not blocked
        with get_session_with_current_tenant() as db_session:
            set_kg_processing_in_progress_status(
                db_session,
                processing_type=KGProcessingType.CLUSTERING,
                in_progress=False,
            )
            db_session.commit()

    task_logger.debug("Completed kg clustering task!")
    return None
@shared_task(
    name=OnyxCeleryTask.KG_CLUSTERING_ONLY,
    soft_time_limit=1000,
    bind=True,
)
def kg_clustering_only(self: Task, *, tenant_id: str) -> int | None:
    """Run only the KG clustering stage for the tenant.

    Marks clustering as in-progress up front and always clears the flag when
    done, even on failure.
    """
    with get_session_with_current_tenant() as db_session:
        search_settings = get_current_search_settings(db_session)
        index_str = search_settings.index_name
        set_kg_processing_in_progress_status(
            db_session, processing_type=KGProcessingType.CLUSTERING, in_progress=True
        )
        db_session.commit()

    task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")
    task_logger.debug("Starting kg clustering-only task!")

    try:
        kg_clustering(
            tenant_id=tenant_id, index_name=index_str, processing_chunk_batch_size=8
        )
    except Exception as e:
        task_logger.exception(f"Error during kg clustering: {e}")
    finally:
        # always clear the clustering flag so future runs are not blocked
        with get_session_with_current_tenant() as db_session:
            set_kg_processing_in_progress_status(
                db_session,
                processing_type=KGProcessingType.CLUSTERING,
                in_progress=False,
            )
            db_session.commit()

    task_logger.debug("Completed kg clustering task!")
    return None

View File

@@ -23,6 +23,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -119,6 +120,11 @@ def document_by_cc_pair_cleanup_task(
chunk_count=chunk_count,
)
delete_document_references_from_kg(
db_session=db_session,
document_id=document_id,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],

View File

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker for knowledge graph processing.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.kg_processing import celery_app
return celery_app
app = get_app()

View File

@@ -12,6 +12,7 @@ from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
@@ -24,8 +25,10 @@ from onyx.chat.models import SubQuestionKey
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -120,6 +123,7 @@ class Answer:
allow_refinement=AGENT_ALLOW_REFINEMENT,
allow_agent_reranking=allow_agent_reranking,
perform_initial_search_decomposition=INITIAL_SEARCH_DECOMPOSITION_ENABLED,
kg_config_settings=get_kg_config_settings(db_session),
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
@@ -134,10 +138,17 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search:
if self.graph_config.behavior.use_agentic_search and (
self.graph_config.inputs.persona
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
and self.graph_config.inputs.persona.name.startswith("KG Dev")
):
run_langgraph = run_kb_graph
elif self.graph_config.behavior.use_agentic_search:
run_langgraph = run_agent_search_graph
elif (
self.graph_config.inputs.persona
and USE_DIV_CON_AGENT
and self.graph_config.inputs.persona.description.startswith(
"DivCon Beta Agent"
)

View File

@@ -79,6 +79,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
@@ -96,6 +97,15 @@ from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import save_files
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.configuration import populate_default_account_employee_definitions
from onyx.kg.configuration import populate_default_grounded_entity_types
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.kg.resets.reset_extractions import reset_extraction_kg_index
from onyx.kg.resets.reset_index import reset_full_kg_index
from onyx.kg.resets.reset_normalizations import reset_normalization_kg_index
from onyx.kg.resets.reset_source import reset_source_kg_index
from onyx.kg.resets.reset_vespa import reset_vespa_kg_index
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -560,6 +570,57 @@ def stream_chat_message_objects(
llm: LLM
kg_config_settings = get_kg_config_settings(db_session)
if kg_config_settings.KG_ENABLED:
# Temporarily, until we have a draft UI for the KG Operations/Management
# get Vespa index
search_settings = get_current_search_settings(db_session)
index_str = search_settings.index_name
if new_msg_req.message == "kg_e":
kg_extraction(tenant_id, index_str)
raise Exception("Extractions done")
elif new_msg_req.message == "kg_c":
kg_clustering(tenant_id, index_str)
raise Exception("Clustering done")
elif new_msg_req.message == "kg":
reset_vespa_kg_index(tenant_id, index_str)
reset_full_kg_index()
kg_extraction(tenant_id, index_str)
kg_clustering(tenant_id, index_str)
raise Exception("Full KG index reset done")
elif new_msg_req.message == "kg_rs_full":
reset_full_kg_index()
raise Exception("Full KG index reset done")
elif new_msg_req.message == "kg_rs_extraction":
reset_extraction_kg_index()
raise Exception("Extraction KG index reset done")
elif new_msg_req.message == "kg_rs_normalization":
reset_normalization_kg_index()
raise Exception("Normalization KG index reset done")
elif new_msg_req.message.startswith("kg_rs_source:"):
source_name = new_msg_req.message.split(":")[1].strip()
reset_source_kg_index(source_name, tenant_id, index_str)
raise Exception(f"KG index reset for source {source_name} done")
elif new_msg_req.message == "kg_rs_vespa":
reset_vespa_kg_index(tenant_id, index_str)
raise Exception("Vespa KG index reset done")
elif new_msg_req.message == "kg_setup":
populate_default_grounded_entity_types()
populate_default_account_employee_definitions()
raise Exception("KG setup done")
try:
# Move these variables inside the try block
user_id = user.id if user is not None else None

View File

@@ -184,6 +184,13 @@ POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE") or 10
)
POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW") or 5
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
@@ -309,6 +316,11 @@ try:
except ValueError:
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 1024
@@ -746,3 +758,9 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)
# Knowledge Graph Read Only User Configuration
DB_READONLY_USER: str = os.environ.get("DB_READONLY_USER", "db_readonly_user")
DB_READONLY_PASSWORD: str = urllib.parse.quote_plus(
os.environ.get("DB_READONLY_PASSWORD") or "password"
)

View File

@@ -102,3 +102,5 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
== "true"
)
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"

View File

@@ -68,6 +68,7 @@ POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -324,6 +325,9 @@ class OnyxCeleryQueues:
# Monitoring queue
MONITORING = "monitoring"
# KG processing queue
KG_PROCESSING = "kg_processing"
class OnyxRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
@@ -357,6 +361,9 @@ class OnyxRedisLocks:
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
# KG processing
KG_PROCESSING_LOCK = "da_lock:kg_processing"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -373,6 +380,9 @@ class OnyxRedisSignals:
"signal:block_validate_connector_deletion_fences"
)
# KG processing
CHECK_KG_PROCESSING_BEAT_LOCK = "da_lock:check_kg_processing_beat"
class OnyxRedisConstants:
ACTIVE_FENCES = "active_fences"
@@ -456,6 +466,12 @@ class OnyxCeleryTask:
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
# KG processing
CHECK_KG_PROCESSING = "check_kg_processing"
KG_PROCESSING = "kg_processing"
KG_CLUSTERING_ONLY = "kg_clustering_only"
CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
# this needs to correspond to the matching entry in supervisord
ONYX_CELERY_BEAT_HEARTBEAT_KEY = "onyx:celery:beat:heartbeat"
@@ -468,3 +484,8 @@ if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
class OnyxCallTypes(str, Enum):
FIREFLIES = "fireflies"
GONG = "gong"

View File

@@ -0,0 +1,102 @@
import os


def _int_env(var_name: str, default: str) -> int:
    """Parse an integer configuration value from the environment."""
    return int(os.environ.get(var_name, default))


def _float_env(var_name: str, default: str) -> float:
    """Parse a float configuration value from the environment."""
    return float(os.environ.get(var_name, default))


def _clamped_weight(var_name: str, default: str, floor: float) -> float:
    """Parse a rerank weight from the environment, clamped to [floor, 1]."""
    return max(floor, min(1, _float_env(var_name, default)))


# Retrieval / display limits.
KG_RESEARCH_NUM_RETRIEVED_DOCS: int = _int_env("KG_RESEARCH_NUM_RETRIEVED_DOCS", "25")
KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES: int = _int_env(
    "KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES", "10"
)

# Per-step timeouts for the KG pipeline.
KG_ENTITY_EXTRACTION_TIMEOUT: int = _int_env("KG_ENTITY_EXTRACTION_TIMEOUT", "15")
KG_RELATIONSHIP_EXTRACTION_TIMEOUT: int = _int_env(
    "KG_RELATIONSHIP_EXTRACTION_TIMEOUT", "15"
)
KG_STRATEGY_GENERATION_TIMEOUT: int = _int_env("KG_STRATEGY_GENERATION_TIMEOUT", "20")
KG_SQL_GENERATION_TIMEOUT: int = _int_env("KG_SQL_GENERATION_TIMEOUT", "25")
KG_FILTER_CONSTRUCTION_TIMEOUT: int = _int_env("KG_FILTER_CONSTRUCTION_TIMEOUT", "15")

KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT: int = _int_env(
    "KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT", "100"
)

KG_FILTERED_SEARCH_TIMEOUT: int = _int_env("KG_FILTERED_SEARCH_TIMEOUT", "30")
KG_OBJECT_SOURCE_RESEARCH_TIMEOUT: int = _int_env(
    "KG_OBJECT_SOURCE_RESEARCH_TIMEOUT", "30"
)
KG_ANSWER_GENERATION_TIMEOUT: int = _int_env("KG_ANSWER_GENERATION_TIMEOUT", "30")

KG_MAX_DEEP_SEARCH_RESULTS: int = _int_env("KG_MAX_DEEP_SEARCH_RESULTS", "30")
KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH: int = _int_env(
    "KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH", "2"
)

# Entity-normalization rerank weights: each n-gram weight is clamped to
# [1e-3, 1] and the triple is then normalized so the weights sum to 1.
_KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT: float = _clamped_weight(
    "KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT", "0.25", 1e-3
)
_KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT: float = _clamped_weight(
    "KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT", "0.25", 1e-3
)
_KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT: float = _clamped_weight(
    "KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT", "0.5", 1e-3
)
_KG_NORMALIZATION_RERANK_NGRAM_SUMS: float = (
    _KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT
    + _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT
    + _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT
)
KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS: tuple[float, float, float] = (
    _KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
    _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
    _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS,
)

# Levenshtein weight is clamped to [0, 1] (unlike the n-gram weights it may be 0).
KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT: float = _clamped_weight(
    "KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT", "0.25", 0
)
KG_NORMALIZATION_RERANK_THRESHOLD: float = _float_env(
    "KG_NORMALIZATION_RERANK_THRESHOLD", "0.3"
)

# Clustering thresholds.
KG_CLUSTERING_RETRIEVE_THRESHOLD: float = _float_env(
    "KG_CLUSTERING_RETRIEVE_THRESHOLD", "0.6"
)
KG_CLUSTERING_THRESHOLD: float = _float_env("KG_CLUSTERING_THRESHOLD", "0.96")

View File

@@ -193,12 +193,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
team {
name
}
assignee {
email
}
previousIdentifiers
subIssueSortOrder
priorityLabel
identifier
url
branchName
state {
id
name
}
customerTicketCount
description
comments {
@@ -267,7 +274,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
title=node["title"],
doc_updated_at=time_str_to_utc(node["updatedAt"]),
metadata={
"team": node["team"]["name"],
k: str(v)
for k, v in {
"team": (node.get("team") or {}).get("name"),
"assignee": (node.get("assignee") or {}).get("email"),
"state": (node.get("state") or {}).get("name"),
"priority": node.get("priority"),
"estimate": node.get("estimate"),
"started_at": node.get("startedAt"),
"completed_at": node.get("completedAt"),
"created_at": node.get("createdAt"),
"due_date": node.get("dueDate"),
}.items()
if v is not None
},
)
)

View File

@@ -134,29 +134,26 @@ def process_jira_issue(
metadata_dict: dict[str, str | list[str]] = {}
people = set()
try:
creator = best_effort_get_field_from_issue(issue, _FIELD_REPORTER)
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_REPORTER_EMAIL] = email
except Exception:
# Author should exist but if not, doesn't matter
pass
creator = best_effort_get_field_from_issue(issue, _FIELD_REPORTER)
if creator is not None and (
basic_expert_info := best_effort_basic_expert_info(creator)
):
people.add(basic_expert_info)
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_REPORTER_EMAIL] = email
try:
assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE)
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
except Exception:
# Author should exist but if not, doesn't matter
pass
assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE)
if assignee is not None and (
basic_expert_info := best_effort_basic_expert_info(assignee)
):
people.add(basic_expert_info)
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
metadata_dict[_FIELD_KEY] = issue.key
if priority := best_effort_get_field_from_issue(issue, _FIELD_PRIORITY):
metadata_dict[_FIELD_PRIORITY] = priority.name
if status := best_effort_get_field_from_issue(issue, _FIELD_STATUS):
@@ -178,20 +175,17 @@ def process_jira_issue(
):
metadata_dict[_FIELD_RESOLUTION_DATE_KEY] = resolutiondate
try:
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
if parent:
metadata_dict[_FIELD_PARENT] = parent.key
except Exception:
# Parent should exist but if not, doesn't matter
pass
try:
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
if project:
metadata_dict[_FIELD_PROJECT_NAME] = project.name
metadata_dict[_FIELD_PROJECT] = project.key
except Exception:
# Project should exist.
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
if parent is not None:
metadata_dict[_FIELD_PARENT] = parent.key
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
if project is not None:
metadata_dict[_FIELD_PROJECT_NAME] = project.name
metadata_dict[_FIELD_PROJECT] = project.key
else:
logger.error(f"Project should exist but does not for {issue.key}")
return Document(

View File

@@ -23,15 +23,19 @@ JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
display_name = None
email = None
if hasattr(obj, "displayName"):
display_name = obj.displayName
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
try:
if hasattr(obj, "displayName"):
display_name = obj.displayName
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
except Exception:
return None
if not email and not display_name:
return None

View File

@@ -35,6 +35,28 @@ logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
"Opportunity": {
"Account": "account",
"FiscalQuarter": "fiscal_quarter",
"FiscalYear": "fiscal_year",
"IsClosed": "is_closed",
"Name": "name",
"StageName": "stage_name",
"Type": "type",
"Amount": "amount",
"CloseDate": "close_date",
"Probability": "probability",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
},
"Contact": {
"Account": "account",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
},
}
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -170,6 +192,8 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")
# Always want to make sure account is grabbed for reference purposes
all_types.add("Account")
logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -282,7 +306,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
f"remaining={len(updated_ids) - docs_processed}"
)
for parent_id in parent_id_batch:
parent_object = sf_db.get_record(parent_id, parent_type)
parent_object = sf_db.get_record(
parent_id, parent_type, isChild=False
)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
@@ -294,6 +320,18 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = parent_object.data[
sf_attribute
]
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)

View File

@@ -172,7 +172,7 @@ def convert_sf_object_to_doc(
sections = [_extract_section(sf_object, base_url)]
for id in sf_db.get_child_ids(sf_object.id):
if not (child_object := sf_db.get_record(id)):
if not (child_object := sf_db.get_record(id, isChild=True)):
continue
sections.append(_extract_section(child_object, base_url))

View File

@@ -456,7 +456,7 @@ class OnyxSalesforceSQLite:
return result[0]
def get_record(
self, object_id: str, object_type: str | None = None
self, object_id: str, object_type: str | None = None, isChild: bool = False
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if self._conn is None:
@@ -469,15 +469,44 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
# Get the object data and account data
if object_type == "Account" or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
else:
cursor.execute(
"SELECT pso.data, r.parent_id as parent_id, sso.object_type FROM salesforce_objects pso \
LEFT JOIN relationships r on r.child_id = pso.id \
LEFT JOIN salesforce_objects sso on r.parent_id = sso.id \
WHERE pso.id = ? ",
(object_id,),
)
result = cursor.fetchall()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
data = json.loads(result[0][0])
if object_type != "Account":
# convert any account ids of the relationships back into data fields, with name
for row in result:
# the following skips Account objects.
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(self, object_type: str) -> list[str]:

View File

@@ -111,6 +111,11 @@ class BaseFilters(BaseModel):
document_set: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
kg_sources: list[str] | None = None
kg_chunk_id_zero_only: bool | None = False
class UserFileFilters(BaseModel):

View File

@@ -183,6 +183,11 @@ def retrieval_preprocessing(
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
kg_entities=preset_filters.kg_entities,
kg_relationships=preset_filters.kg_relationships,
kg_terms=preset_filters.kg_terms,
kg_sources=preset_filters.kg_sources,
kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -16,6 +16,7 @@ from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.kg.models import KGConnectorData
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.models import StatusResponse
@@ -334,3 +335,29 @@ def mark_ccpair_with_indexing_trigger(
except Exception:
db_session.rollback()
raise
def get_kg_enabled_connectors(db_session: Session) -> list[KGConnectorData]:
    """Fetch all connectors that have knowledge-graph processing enabled.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[KGConnectorData]: One entry per connector whose
        kg_processing_enabled flag is set, carrying the connector id, the
        lower-cased source name, and the connector's KG coverage window in days.

    Raises:
        Exception: Re-raises any database error after logging it.
    """
    try:
        stmt = select(Connector.id, Connector.source, Connector.kg_coverage_days).where(
            Connector.kg_processing_enabled
        )
        rows = db_session.execute(stmt).fetchall()
        return [
            KGConnectorData(id=row[0], source=row[1].lower(), kg_coverage_days=row[2])
            for row in rows
        ]
    except Exception:
        # logger.exception captures the traceback; a bare `raise` preserves
        # the original exception and its chain (unlike `raise e`).
        logger.exception("Error fetching KG-enabled connectors")
        raise

View File

@@ -4,6 +4,7 @@ from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import and_
@@ -21,23 +22,36 @@ from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.entities import delete_from_kg_entities__no_commit
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document
from onyx.db.models import Document as DbDocument
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import KGEntity
from onyx.db.models import KGRelationship
from onyx.db.models import User
from onyx.db.relationships import delete_from_kg_relationships__no_commit
from onyx.db.relationships import (
delete_from_kg_relationships_extraction_staging__no_commit,
)
from onyx.db.tag import delete_document_tags_for_documents__no_commit
from onyx.db.utils import model_to_dict
from onyx.document_index.interfaces import DocumentMetadata
from onyx.kg.models import KGStage
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
@@ -392,6 +406,7 @@ def upsert_documents(
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
kg_stage=KGStage.NOT_STARTED,
)
)
for doc in seen_documents.values()
@@ -605,7 +620,29 @@ def delete_documents_complete__no_commit(
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
# Start with the kg references
delete_from_kg_relationships__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_entities__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_relationships_extraction_staging__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_entities_extraction_staging__no_commit(
db_session=db_session,
document_ids=document_ids,
)
# Continue with deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
@@ -858,3 +895,374 @@ def fetch_chunk_count_for_document(
) -> int | None:
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()
def get_unprocessed_kg_document_batch_for_connector(
    db_session: Session,
    connector_id: int,
    kg_coverage_start: datetime,
    kg_max_coverage_days: int,
    batch_size: int = 100,
) -> list[DbDocument]:
    """
    Retrieves a batch of documents that have not been processed for knowledge graph extraction.

    Args:
        db_session (Session): The database session to use
        connector_id (int): The ID of the connector to get documents for
        kg_coverage_start (datetime): Ignore documents updated before this time
        kg_max_coverage_days (int): Also ignore documents older than this many days
        batch_size (int): The maximum number of documents to retrieve

    Returns:
        list[DbDocument]: List of documents that need KG processing, newest first
    """
    # Use a timezone-aware "now" so the comparison against the aware
    # doc_updated_at column is well-defined (matches update_document_kg_info,
    # which stamps kg_processing_time with datetime.now(timezone.utc)).
    coverage_cutoff = datetime.now(timezone.utc) - timedelta(days=kg_max_coverage_days)

    stmt = (
        select(DbDocument)
        .join(
            DocumentByConnectorCredentialPair,
            DbDocument.id == DocumentByConnectorCredentialPair.id,
        )
        .where(
            and_(
                DocumentByConnectorCredentialPair.connector_id == connector_id,
                DbDocument.doc_updated_at >= kg_coverage_start,
                DbDocument.doc_updated_at >= coverage_cutoff,
                or_(
                    # never processed ...
                    DbDocument.kg_stage.is_(None),
                    DbDocument.kg_stage == KGStage.NOT_STARTED,
                    # ... or updated since it was last processed
                    DbDocument.doc_updated_at > DbDocument.kg_processing_time,
                ),
            )
        )
        .distinct()
        .order_by(DbDocument.doc_updated_at.desc())
        .limit(batch_size)
    )

    # Pure read: no flush needed before returning.
    return list(db_session.scalars(stmt).all())
def get_kg_extracted_document_ids(db_session: Session) -> list[str]:
    """Return the IDs of all documents whose KG stage is EXTRACTED.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[str]: IDs of documents that have completed KG extraction
    """
    extracted_only = DbDocument.kg_stage == KGStage.EXTRACTED
    query = select(DbDocument.id).where(extracted_only)
    return [doc_id for doc_id in db_session.scalars(query)]
def update_document_kg_info(
    db_session: Session, document_id: str, kg_stage: KGStage
) -> None:
    """Record a document's KG stage and stamp its processing time.

    Args:
        db_session (Session): The database session to use
        document_id (str): The ID of the document to update
        kg_stage (KGStage): The new KG processing stage for the document
    """
    processed_at = datetime.now(timezone.utc)
    db_session.execute(
        update(DbDocument)
        .where(DbDocument.id == document_id)
        .values(kg_stage=kg_stage, kg_processing_time=processed_at)
    )
def update_document_kg_stage(
    db_session: Session,
    document_id: str,
    kg_stage: KGStage,
) -> None:
    """Set the KG stage for a single document and flush the session.

    Args:
        db_session (Session): The database session to use
        document_id (str): The ID of the document to update
        kg_stage (KGStage): The stage to assign to the document
    """
    transition = update(DbDocument).where(DbDocument.id == document_id)
    db_session.execute(transition.values(kg_stage=kg_stage))
    # Flush so the new stage is visible to subsequent queries in this session.
    db_session.flush()
def get_all_kg_extracted_documents_info(
    db_session: Session,
) -> list[str]:
    """Retrieve the IDs of all documents whose KG extraction is complete.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[str]: IDs (ordered by document ID) of documents whose
        kg_stage is EXTRACTED
    """
    stmt = (
        select(DbDocument.id)
        .where(DbDocument.kg_stage == KGStage.EXTRACTED)
        .order_by(DbDocument.id)
    )
    # Use scalars() so each element is the id itself; execute().all() yields
    # Row tuples, and str(row) would produce "('<id>',)" instead of the id.
    return [str(doc_id) for doc_id in db_session.scalars(stmt).all()]
def get_base_llm_doc_information(
    db_session: Session, document_ids: list[str]
) -> list[str]:
    """Build short markdown source listings for the given documents.

    Args:
        db_session (Session): The database session to use
        document_ids (list[str]): IDs of the documents to describe

    Returns:
        list[str]: Up to KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES markdown
        bullet lines of the form "* [semantic_id](link) (updated_at)"
    """
    stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
    # Only the first KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES entries are ever
    # returned, so truncate before formatting instead of after.
    docs = db_session.scalars(stmt).all()[:KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES]
    return [
        f"* [{doc.semantic_id}]({doc.link}) ({doc.doc_updated_at})" for doc in docs
    ]
def get_document_updated_at(
    document_id: str,
    db_session: Session,
) -> datetime | None:
    """Look up the doc_updated_at timestamp for a document.

    The given ID may be a plain document ID or a two-part KG entity ID; in
    the latter case the document portion (per split_entity_id) is used.

    Args:
        document_id (str): The ID of the document to query
        db_session (Session): The database session to use

    Returns:
        datetime | None: The timestamp if the document exists, else None

    Raises:
        ValueError: If the ID splits into more than two parts.
    """
    id_parts = split_entity_id(document_id)
    if len(id_parts) > 2:
        raise ValueError(f"Invalid document ID: {document_id}")
    if len(id_parts) == 2:
        document_id = id_parts[1]

    query = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
    return db_session.execute(query).scalar_one_or_none()
def reset_all_document_kg_stages(db_session: Session) -> int:
    """Move every document with any KG progress back to NOT_STARTED.

    Args:
        db_session (Session): The database session to use

    Returns:
        int: Number of documents whose stage was reset
    """
    reset_stmt = (
        update(DbDocument)
        .where(DbDocument.kg_stage != KGStage.NOT_STARTED)
        .values(kg_stage=KGStage.NOT_STARTED)
    )
    outcome = db_session.execute(reset_stmt)
    # rowcount is present at runtime for UPDATE results; the hasattr guard
    # only exists to satisfy the type checker.
    return outcome.rowcount if hasattr(outcome, "rowcount") else 0
def update_document_kg_stages(
    db_session: Session, source_stage: KGStage, target_stage: KGStage
) -> int:
    """Bulk-move documents from one KG stage to another.

    Part of the reset flow, e.g. sending documents that were extracted but
    never clustered back to an earlier stage.

    Args:
        db_session (Session): The database session to use
        source_stage (KGStage): Only documents currently in this stage are updated
        target_stage (KGStage): The stage to move matching documents into

    Returns:
        int: Number of documents updated
    """
    transition = (
        update(DbDocument)
        .where(DbDocument.kg_stage == source_stage)
        .values(kg_stage=target_stage)
    )
    outcome = db_session.execute(transition)
    # rowcount is present at runtime for UPDATE results; the hasattr guard
    # only exists to satisfy the type checker.
    return outcome.rowcount if hasattr(outcome, "rowcount") else 0
def get_skipped_kg_documents(db_session: Session) -> list[str]:
    """Return the IDs of all documents skipped during KG processing.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[str]: IDs of documents whose kg_stage is SKIPPED
    """
    skipped_filter = DbDocument.kg_stage == KGStage.SKIPPED
    query = select(DbDocument.id).where(skipped_filter)
    return [doc_id for doc_id in db_session.scalars(query)]
def get_kg_doc_info_for_entity_name(
    db_session: Session, document_id: str, entity_type: str
) -> KGEntityDocInfo:
    """
    Look up the document's semantic id and link and build the corresponding
    KG entity naming info.

    When the document is unknown, a placeholder KGEntityDocInfo is returned
    (note: in that branch the raw entity_type is used, not the upper-cased
    form used in the found branch).
    """
    row = (
        db_session.query(Document.semantic_id, Document.link)
        .filter(Document.id == document_id)
        .first()
    )

    if row is None:
        return KGEntityDocInfo(
            doc_id=None,
            doc_semantic_id=None,
            doc_link=None,
            semantic_entity_name=f"{entity_type}:{document_id}",
            semantic_linked_entity_name=f"{entity_type}:{document_id}",
        )

    semantic_id, link = row
    return KGEntityDocInfo(
        doc_id=document_id,
        doc_semantic_id=semantic_id,
        doc_link=link,
        semantic_entity_name=f"{entity_type.upper()}:{semantic_id}",
        semantic_linked_entity_name=f"[{entity_type.upper()}:{semantic_id}]({link})",
    )
def check_for_documents_needing_kg_processing(
    db_session: Session, kg_coverage_start: datetime, kg_max_coverage_days: int
) -> bool:
    """Check if there are any documents that need KG processing.
    A document needs KG processing if:
    1. It is associated with a connector that has kg_processing_enabled = true
    2. AND its doc_updated_at falls inside the coverage window
       (>= kg_coverage_start AND at most kg_max_coverage_days old)
    3. AND either:
       - Its kg_stage is NOT_STARTED or NULL
       - OR its doc_updated_at timestamp is greater than its kg_processing_time
    Args:
        db_session (Session): The database session to use
        kg_coverage_start (datetime): Earliest doc_updated_at that is considered
        kg_max_coverage_days (int): Maximum document age in days to consider
    Returns:
        bool: True if there are any documents needing KG processing, False otherwise
    """
    stmt = (
        select(1)
        .select_from(DbDocument)
        .join(
            DocumentByConnectorCredentialPair,
            DbDocument.id == DocumentByConnectorCredentialPair.id,
        )
        .join(
            Connector,
            DocumentByConnectorCredentialPair.connector_id == Connector.id,
        )
        .where(
            and_(
                Connector.kg_processing_enabled.is_(True),
                DbDocument.doc_updated_at >= kg_coverage_start,
                # NOTE(review): datetime.now() is timezone-naive here — confirm
                # that doc_updated_at is stored naive/UTC so the comparison holds
                DbDocument.doc_updated_at
                >= datetime.now() - timedelta(days=kg_max_coverage_days),
                or_(
                    # inner or_ is redundant (or_ is variadic) but harmless
                    or_(
                        DbDocument.kg_stage.is_(None),
                        DbDocument.kg_stage == KGStage.NOT_STARTED,
                    ),
                    DbDocument.doc_updated_at > DbDocument.kg_processing_time,
                ),
            )
        )
        .exists()
    )
    # EXISTS wrapped in an outer SELECT; scalar() is None when no row matches
    return db_session.execute(select(stmt)).scalar() or False
def check_for_documents_needing_kg_clustering(db_session: Session) -> bool:
    """Check if there are any documents that need KG clustering.
    A document needs KG clustering if:
    1. It is associated with a connector that has kg_processing_enabled = true
    2. AND its connector-credential pair is not currently being deleted
    3. AND either:
       - Its kg_stage is EXTRACTED
       - OR its last_modified timestamp is greater than its kg_processing_time
    Args:
        db_session (Session): The database session to use
    Returns:
        bool: True if there are any documents needing KG clustering, False otherwise
    """
    stmt = (
        select(1)
        .select_from(DbDocument)
        .join(
            DocumentByConnectorCredentialPair,
            DbDocument.id == DocumentByConnectorCredentialPair.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                DocumentByConnectorCredentialPair.connector_id
                == ConnectorCredentialPair.connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == ConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .where(
            and_(
                Connector.kg_processing_enabled.is_(True),
                # skip documents whose cc-pair is mid-deletion
                ConnectorCredentialPair.status
                != ConnectorCredentialPairStatus.DELETING,
                or_(
                    DbDocument.kg_stage == KGStage.EXTRACTED,
                    DbDocument.last_modified > DbDocument.kg_processing_time,
                ),
            )
        )
        .exists()
    )
    # EXISTS wrapped in an outer SELECT; scalar() is None when no row matches
    return db_session.execute(select(stmt)).scalar() or False
def get_document_kg_entities_and_relationships(
    db_session: Session, document_id: str
) -> tuple[list[KGEntity], list[KGRelationship]]:
    """
    Fetch the KG entities extracted from the given document, plus every
    relationship that touches one of those entities or cites the document
    as its source.
    """
    entities = (
        db_session.query(KGEntity).filter(KGEntity.document_id == document_id).all()
    )
    if not entities:
        return [], []

    id_names = [e.id_name for e in entities]
    relationship_filter = or_(
        KGRelationship.source_node.in_(id_names),
        KGRelationship.target_node.in_(id_names),
        KGRelationship.source_document == document_id,
    )
    relationships = (
        db_session.query(KGRelationship).filter(relationship_filter).all()
    )
    return entities, relationships

View File

@@ -27,6 +27,8 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -187,7 +189,9 @@ def is_valid_schema_name(name: str) -> bool:
class SqlEngine:
_engine: Engine | None = None
_readonly_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_readonly_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
@classmethod
@@ -252,12 +256,80 @@ class SqlEngine:
cls._engine = engine
@classmethod
def init_readonly_engine(
    cls,
    pool_size: int,
    # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
    max_overflow: int,
    **extra_engine_kwargs: Any,
) -> None:
    """Create the shared read-only engine (idempotent; the first call wins).

    Connects as the restricted DB_READONLY_USER rather than the primary
    application user.

    NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
    important args, and if incorrectly specified, we have run into hitting the pool
    limit / using too many connections and overwhelming the database."""
    with cls._readonly_lock:
        # idempotent: keep the first engine that was created
        if cls._readonly_engine:
            return

        if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
            raise ValueError(
                "Custom database user credentials not configured in environment variables"
            )

        # Build connection string with custom user
        connection_string = build_connection_string(
            user=DB_READONLY_USER,
            password=DB_READONLY_PASSWORD,
            use_iam_auth=False,  # Custom users typically don't use IAM auth
            db_api=SYNC_DB_API,  # Explicitly use sync DB API
        )

        # Start with base kwargs that are valid for all pool types
        final_engine_kwargs: dict[str, Any] = {}

        if POSTGRES_USE_NULL_POOL:
            # if null pool is specified, then we need to make sure that
            # we remove any passed in kwargs related to pool size that would
            # cause the initialization to fail
            final_engine_kwargs.update(extra_engine_kwargs)

            final_engine_kwargs["poolclass"] = pool.NullPool
            if "pool_size" in final_engine_kwargs:
                del final_engine_kwargs["pool_size"]
            if "max_overflow" in final_engine_kwargs:
                del final_engine_kwargs["max_overflow"]
        else:
            final_engine_kwargs["pool_size"] = pool_size
            final_engine_kwargs["max_overflow"] = max_overflow
            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

            # any passed in kwargs override the defaults
            final_engine_kwargs.update(extra_engine_kwargs)

        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
        # echo=True here for inspecting all emitted db queries
        engine = create_engine(connection_string, **final_engine_kwargs)

        if USE_IAM_AUTH:
            # NOTE(review): this registers the IAM token provider even though
            # the connection string above was built with use_iam_auth=False —
            # confirm whether IAM auth is actually intended for the readonly user
            event.listen(engine, "do_connect", provide_iam_token)

        cls._readonly_engine = engine
@classmethod
def get_engine(cls) -> Engine:
    """Return the primary engine; raises if init_engine has not been called."""
    engine = cls._engine
    if not engine:
        raise RuntimeError("Engine not initialized. Must call init_engine first.")
    return engine
@classmethod
def get_readonly_engine(cls) -> Engine:
    """Return the read-only engine; raises if init_readonly_engine has not run."""
    engine = cls._readonly_engine
    if not engine:
        raise RuntimeError(
            "Readonly engine not initialized. Must call init_readonly_engine first."
        )
    return engine
@classmethod
def set_app_name(cls, app_name: str) -> None:
    # Record the application name used to tag Postgres connections.
    cls._app_name = app_name
@@ -307,6 +379,10 @@ def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
def get_readonly_sqlalchemy_engine() -> Engine:
    # Thin module-level accessor for the shared read-only engine.
    return SqlEngine.get_readonly_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
@@ -444,6 +520,9 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
except Exception as e:
logger.warning(f"Failed to reset search path: {e}")
connection.rollback()
finally:
cursor.close()
@@ -561,3 +640,49 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@contextmanager
def get_db_readonly_user_session_with_current_tenant() -> (
    Generator[Session, None, None]
):
    """
    Yield a session bound to the read-only-user engine, scoped to the current
    tenant's schema.

    The connection's search_path is pointed at the tenant schema before the
    session is handed out and (in multi-tenant mode) restored afterwards.

    Raises:
        HTTPException: if the current tenant id is not a valid schema name.
    """
    tenant_id = get_current_tenant_id()
    readonly_engine = get_readonly_sqlalchemy_engine()

    # Register the checkout listener at most once; calling event.listen on
    # every invocation would accumulate duplicate listeners on the shared engine.
    if not event.contains(
        readonly_engine, "checkout", _set_search_path_on_checkout__listener
    ):
        event.listen(
            readonly_engine, "checkout", _set_search_path_on_checkout__listener
        )

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with readonly_engine.connect() as connection:
        dbapi_connection = connection.connection
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(f'SET search_path = "{tenant_id}"')
            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                # This is a raw DBAPI cursor, so the statement must be a plain
                # string — a SQLAlchemy text() clause is not accepted here.
                cursor.execute(
                    f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
        finally:
            cursor.close()

        with Session(bind=connection, expire_on_commit=False) as session:
            try:
                yield session
            finally:
                if MULTI_TENANT:
                    # best effort: restore the default search_path for the
                    # pooled connection before it is returned
                    cursor = dbapi_connection.cursor()
                    try:
                        cursor.execute('SET search_path TO "$user", public')
                    except Exception as e:
                        logger.warning(f"Failed to reset search path: {e}")
                        connection.rollback()
                    finally:
                        cursor.close()

309
backend/onyx/db/entities.py Normal file
View File

@@ -0,0 +1,309 @@
import uuid
from datetime import datetime
from datetime import timezone
from typing import List
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
import onyx.db.document as dbdocument
from onyx.db.models import Document
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGEntityType
from onyx.kg.models import KGGroundingType
from onyx.kg.models import KGStage
from onyx.kg.utils.formatting_utils import make_entity_id
def upsert_staging_entity(
    db_session: Session,
    name: str,
    entity_type: str,
    document_id: str | None = None,
    occurrences: int = 1,
    attributes: dict[str, str] | None = None,
    event_time: datetime | None = None,
) -> KGEntityExtractionStaging:
    """Add or update a new staging entity to the database.

    On id_name conflict the existing row is kept and only its occurrence
    count is incremented. When document_id is provided, the document is
    additionally marked as EXTRACTED.

    Args:
        db_session: SQLAlchemy session
        name: Name of the entity
        entity_type: Type of the entity (must match an existing KGEntityType)
        document_id: ID of the document the entity belongs to
        occurrences: Number of times this entity has been found
        attributes: Attributes of the entity
        event_time: Time the entity was added to the database
    Returns:
        KGEntityExtractionStaging: The created entity
    """
    # normalize naming: entity types are stored upper-case, names title-cased
    entity_type = entity_type.upper()
    name = name.title()
    id_name = make_entity_id(entity_type, name)
    attributes = attributes or {}

    # entity types of the form "CLASS-SUBTYPE" are split into class + subtype;
    # otherwise the whole type is the class and there is no subtype.
    # NOTE(review): a type containing more than one "-" would make this
    # unpacking raise ValueError — confirm types never contain multiple dashes
    entity_type_split = entity_type.split("-")
    entity_class, entity_subtype = (
        entity_type_split if len(entity_type_split) == 2 else (entity_type, None)
    )
    entity_key = attributes.get("key")
    entity_parent = attributes.get("parent")
    # drop attributes that live in dedicated columns ("key"/"parent"),
    # source bookkeeping fields, and anything email-like
    keep_attributes = {
        attr_key: attr_val
        for attr_key, attr_val in attributes.items()
        if not (
            (attr_key in ("key", "parent") and entity_class)
            or attr_key in ("object_type", "issuetype")
            or "_email" in attr_key
        )
    }

    # Create new entity
    stmt = (
        pg_insert(KGEntityExtractionStaging)
        .values(
            id_name=id_name,
            name=name,
            entity_type_id_name=entity_type,
            entity_class=entity_class,
            entity_subtype=entity_subtype,
            entity_key=entity_key,
            parent_key=entity_parent,
            document_id=document_id,
            occurrences=occurrences,
            attributes=keep_attributes,
            event_time=event_time,
        )
        .on_conflict_do_update(
            index_elements=["id_name"],
            # on conflict only the occurrence count is bumped; other fields
            # of the existing row are left untouched
            set_=dict(
                occurrences=KGEntityExtractionStaging.occurrences + occurrences,
            ),
        )
        .returning(KGEntityExtractionStaging)
    )
    result = db_session.execute(stmt).scalar()
    if result is None:
        raise RuntimeError(
            f"Failed to create or increment staging entity with id_name: {id_name}"
        )

    # Update the document's kg_stage if document_id is provided
    if document_id is not None:
        db_session.query(Document).filter(Document.id == document_id).update(
            {
                "kg_stage": KGStage.EXTRACTED,
                "kg_processing_time": datetime.now(timezone.utc),
            }
        )
    db_session.flush()

    return result
def transfer_entity(
    db_session: Session,
    entity: KGEntityExtractionStaging,
) -> KGEntity:
    """Transfer an entity from the extraction staging table to the normalized table.

    The normalized row is given a fresh random id_name; deduplication relies
    on the (name, entity_type_id_name, document_id) conflict target, where
    only the occurrence count is incremented. The staging row is then marked
    with the id_name it was transferred to.

    Args:
        db_session: SQLAlchemy session
        entity: Entity to transfer
    Returns:
        KGEntity: The transferred entity
    """
    # Create the transferred entity
    stmt = (
        pg_insert(KGEntity)
        .values(
            # random suffix — uniqueness is per (name, type, document) below
            id_name=make_entity_id(entity.entity_type_id_name, uuid.uuid4().hex[:20]),
            # normalized names are stored casefolded
            name=entity.name.casefold(),
            entity_class=entity.entity_class,
            entity_subtype=entity.entity_subtype,
            entity_key=entity.entity_key,
            alternative_names=entity.alternative_names or [],
            entity_type_id_name=entity.entity_type_id_name,
            document_id=entity.document_id,
            occurrences=entity.occurrences,
            attributes=entity.attributes,
            event_time=entity.event_time,
        )
        .on_conflict_do_update(
            index_elements=["name", "entity_type_id_name", "document_id"],
            set_=dict(
                occurrences=KGEntity.occurrences + entity.occurrences,
            ),
        )
        .returning(KGEntity)
    )
    new_entity = db_session.execute(stmt).scalar()
    if new_entity is None:
        raise RuntimeError(f"Failed to transfer entity with id_name: {entity.id_name}")

    # Update the document's kg_stage if document_id is provided
    if entity.document_id is not None:
        dbdocument.update_document_kg_info(
            db_session,
            document_id=entity.document_id,
            kg_stage=KGStage.NORMALIZED,
        )

    # Update transferred
    db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.id_name == entity.id_name
    ).update({"transferred_id_name": new_entity.id_name})
    db_session.flush()

    return new_entity
def merge_entities(
    db_session: Session, parent: KGEntity, child: KGEntityExtractionStaging
) -> KGEntity:
    """Merge an entity from the extraction staging table into
    an existing entity in the normalized table.

    Only the parent's document_id (if previously unset), alternative_names,
    and occurrence count are updated; the staging row is marked as
    transferred to the parent.

    Args:
        db_session: SQLAlchemy session
        parent: Parent entity to merge into
        child: Child staging entity to merge
    Returns:
        KGEntity: The merged entity
    Raises:
        ValueError: if both entities carry different document_ids
        RuntimeError: if the update affected no row
    """
    # check we're not merging two entities with different document_ids
    if (
        parent.document_id is not None
        and child.document_id is not None
        and parent.document_id != child.document_id
    ):
        raise ValueError(
            "Overwriting the document_id of an entity with a document_id already is not allowed"
        )

    # update the parent entity (only document_id, alternative_names, occurrences)
    setting_doc = parent.document_id is None and child.document_id is not None
    document_id = child.document_id if setting_doc else parent.document_id
    alternative_names = set(parent.alternative_names or [])
    alternative_names.update(child.alternative_names or [])
    # NOTE(review): .lower() is used here while transfer_entity stores names
    # with .casefold() — confirm the intended normalization is consistent
    alternative_names.add(child.name.lower())
    # the parent's own name is never listed as one of its alternatives
    alternative_names.discard(parent.name)

    stmt = (
        update(KGEntity)
        .where(KGEntity.id_name == parent.id_name)
        .values(
            document_id=document_id,
            alternative_names=list(alternative_names),
            occurrences=parent.occurrences + child.occurrences,
        )
        .returning(KGEntity)
    )
    result = db_session.execute(stmt).scalar()
    if result is None:
        raise RuntimeError(f"Failed to merge entities with id_name: {parent.id_name}")

    # Update the document's kg_stage if document_id is set
    if setting_doc and child.document_id is not None:
        dbdocument.update_document_kg_info(
            db_session,
            document_id=child.document_id,
            kg_stage=KGStage.NORMALIZED,
        )

    # Update transferred
    db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.id_name == child.id_name
    ).update({"transferred_id_name": parent.id_name})
    db_session.flush()

    return result
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
    """
    Return the KGEntity associated with the given document_id, or None if no
    entity references that document.

    Args:
        db: SQLAlchemy database session
        document_id: The document ID to search for
    """
    return db.execute(
        select(KGEntity).where(KGEntity.document_id == document_id)
    ).scalar()
def get_grounded_entities_by_types(
    db_session: Session, entity_types: List[str], grounding: KGGroundingType
) -> List[KGEntity]:
    """Fetch all entities whose type is in `entity_types` and whose entity
    type carries the requested grounding.

    Args:
        db_session: SQLAlchemy session
        entity_types: Entity type id_names to filter by
        grounding: Required grounding of the matching entity types

    Returns:
        List of KGEntity objects belonging to the specified entity types
    """
    query = (
        db_session.query(KGEntity)
        .join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
        .filter(
            KGEntity.entity_type_id_name.in_(entity_types),
            KGEntityType.grounding == grounding,
        )
    )
    return query.all()
def get_document_id_for_entity(db_session: Session, entity_id_name: str) -> str | None:
    """Return the document ID linked to the given entity, or None.

    Args:
        db_session: SQLAlchemy database session
        entity_id_name: The entity id_name to look up
    """
    match = (
        db_session.query(KGEntity)
        .filter(KGEntity.id_name == entity_id_name)
        .first()
    )
    if match is None:
        return None
    return match.document_id
def delete_from_kg_entities_extraction_staging__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete all staging entities tied to the given documents (caller commits)."""
    query = db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.document_id.in_(document_ids)
    )
    query.delete(synchronize_session=False)
def delete_from_kg_entities__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete all normalized entities tied to the given documents (caller commits)."""
    query = db_session.query(KGEntity).filter(
        KGEntity.document_id.in_(document_ids)
    )
    query.delete(synchronize_session=False)
def get_entity_name(db_session: Session, entity_id_name: str) -> str | None:
    """Return the name of the entity with the given id_name, or None."""
    match = (
        db_session.query(KGEntity)
        .filter(KGEntity.id_name == entity_id_name)
        .first()
    )
    if match is None:
        return None
    return match.name

View File

@@ -0,0 +1,257 @@
from typing import List
from sqlalchemy.orm import Session
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import KGEntityType
from onyx.kg.kg_default_entity_definitions import KGDefaultAccountEmployeeDefinitions
from onyx.kg.kg_default_entity_definitions import (
KGDefaultPrimaryGroundedEntityDefinitions,
)
from onyx.kg.models import KGGroundingType
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
    """Return every entity type whose entity_values column is populated.

    Args:
        db_session: SQLAlchemy session
    """
    query = db_session.query(KGEntityType).filter(
        KGEntityType.entity_values.isnot(None)
    )
    return query.all()
def get_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
    """Return every entity type whose grounding is GROUNDED.

    Args:
        db_session: SQLAlchemy session
    """
    query = db_session.query(KGEntityType).filter(
        KGEntityType.grounding == KGGroundingType.GROUNDED
    )
    return query.all()
def get_entity_types_with_grounded_source_name(
    db_session: Session,
) -> List[KGEntityType]:
    """Return every entity type whose grounded_source_name is populated.

    Args:
        db_session: SQLAlchemy session
    """
    query = db_session.query(KGEntityType).filter(
        KGEntityType.grounded_source_name.isnot(None)
    )
    return query.all()
def get_entity_type_by_grounded_source_name(
    db_session: Session, grounded_source_name: str
) -> KGEntityType | None:
    """Get the entity type whose grounded_source_name matches.

    Args:
        db_session: SQLAlchemy session
        grounded_source_name: The grounded_source_name of the entity type to
            retrieve (a source-name string, not a KGGroundingType — the
            previous annotation was incorrect)

    Returns:
        The matching KGEntityType, or None if no entity type matches
    """
    return (
        db_session.query(KGEntityType)
        .filter(KGEntityType.grounded_source_name == grounded_source_name)
        .first()
    )
def get_entity_types(
    db_session: Session,
    active: bool | None = True,
) -> list[KGEntityType]:
    """Return entity types ordered by id_name.

    active=True/False filters on the active flag; active=None returns all
    entity types regardless of the flag.
    """
    query = db_session.query(KGEntityType)
    if active is not None:
        query = query.filter(KGEntityType.active == active)
    return query.order_by(KGEntityType.id_name).all()
def populate_default_primary_grounded_entity_type_information(
    db_session: Session,
) -> None:
    """Seed the KGEntityType table with the default primary grounded entity types.

    Existing entity types are left untouched; newly created ones start with
    active=False (unlike the employee/account defaults, which carry their own
    active flag). The vendor placeholder in each description is substituted
    with the configured KG vendor name.

    Args:
        db_session: SQLAlchemy session
    Raises:
        ValueError: if the KG is disabled or the vendor configuration is missing
    """
    # get kg config information
    kg_config_settings = get_kg_config_settings(db_session)
    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    if not kg_config_settings.KG_VENDOR:
        raise ValueError("KG_VENDOR is not set")
    if not kg_config_settings.KG_VENDOR_DOMAINS:
        raise ValueError("KG_VENDOR_DOMAINS is not set")

    # Get all existing entity types
    existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}

    # Create an instance of the default definitions
    default_definitions = KGDefaultPrimaryGroundedEntityDefinitions()

    # Iterate over all attributes in the default definitions
    for id_name, definition in default_definitions.model_dump().items():
        # Skip if this entity type already exists
        if id_name in existing_entity_types:
            continue

        # Create new entity type
        description = definition["description"].replace(
            "---vendor_name---", kg_config_settings.KG_VENDOR
        )
        new_entity_type = KGEntityType(
            id_name=id_name,
            description=description,
            grounding=definition["grounding"],
            grounded_source_name=definition["grounded_source_name"],
            # primary grounded types start inactive until explicitly enabled
            active=False,
        )
        # Add to session
        db_session.add(new_entity_type)

    # Commit changes
    db_session.flush()
def populate_default_employee_account_information(db_session: Session) -> None:
    """Seed the KGEntityType table with the default employee/account entity types.

    Existing entity types are left untouched. Unlike the primary grounded
    defaults, the active flag here is taken from each default definition.

    Args:
        db_session: SQLAlchemy session
    Raises:
        ValueError: if the KG is disabled or the vendor configuration is missing
    """
    # get kg config information
    kg_config_settings = get_kg_config_settings(db_session)
    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    if not kg_config_settings.KG_VENDOR:
        raise ValueError("KG_VENDOR is not set")
    if not kg_config_settings.KG_VENDOR_DOMAINS:
        raise ValueError("KG_VENDOR_DOMAINS is not set")

    # Get all existing entity types
    existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}

    # Create an instance of the default definitions
    default_definitions = KGDefaultAccountEmployeeDefinitions()

    # Iterate over all attributes in the default definitions
    for id_name, definition in default_definitions.model_dump().items():
        # Skip if this entity type already exists
        if id_name in existing_entity_types:
            continue

        # Create new entity type
        description = definition["description"].replace(
            "---vendor_name---", kg_config_settings.KG_VENDOR
        )
        new_entity_type = KGEntityType(
            id_name=id_name,
            description=description,
            grounding=definition["grounding"],
            grounded_source_name=definition["grounded_source_name"],
            active=definition["active"],
        )
        # Add to session
        db_session.add(new_entity_type)

    # Commit changes
    db_session.flush()
def get_grounded_entity_types_with_null_grounded_source(
    db_session: Session,
) -> List[KGEntityType]:
    """Return GROUNDED entity types that have no grounded_source_name set.

    Args:
        db_session: SQLAlchemy session
    """
    query = db_session.query(KGEntityType).filter(
        KGEntityType.grounded_source_name.is_(None),
        KGEntityType.grounding == KGGroundingType.GROUNDED,
    )
    return query.all()
def get_entity_types_by_grounding(
    db_session: Session,
    grounding: KGGroundingType,
) -> List[KGEntityType]:
    """Return every entity type with the given grounding.

    Args:
        db_session: SQLAlchemy session
        grounding: The grounding type to filter by
    """
    query = db_session.query(KGEntityType).filter(
        KGEntityType.grounding == grounding
    )
    return query.all()
def get_grounded_source_name(db_session: Session, entity_type: str) -> str | None:
    """
    Return the grounded source name of the entity type with id_name
    `entity_type`, or None if the entity type does not exist.
    """
    entity_type_row = (
        db_session.query(KGEntityType)
        .filter(KGEntityType.id_name == entity_type)
        .first()
    )
    return None if entity_type_row is None else entity_type_row.grounded_source_name

View File

@@ -0,0 +1,141 @@
from datetime import datetime
from enum import Enum
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.db.models import KGConfig
from onyx.kg.models import KGConfigSettings
from onyx.kg.models import KGConfigVars
class KGProcessingType(Enum):
    """Which phase of KG processing a status flag refers to."""

    EXTRACTION = "extraction"
    CLUSTERING = "clustering"
def get_kg_enablement(db_session: Session) -> bool:
    """Return True iff a KG_ENABLED config row exists with the value ["true"].

    NOTE: both column comparisons must be passed as separate filter criteria
    (implicitly AND-ed). The previous implementation combined them with the
    Python `and` operator, which does not build a SQL AND — evaluating the
    truthiness of a SQLAlchemy column expression raises a TypeError.
    """
    check = (
        db_session.query(KGConfig.kg_variable_values)
        .filter(
            KGConfig.kg_variable_name == "KG_ENABLED",
            KGConfig.kg_variable_values == ["true"],
        )
        .first()
    )
    return check is not None
def get_kg_config_settings(db_session: Session) -> KGConfigSettings:
    """Assemble a KGConfigSettings object from the KGConfig key/value rows.

    Each row holds a variable name plus a list of string values. Unknown
    variable names are ignored; missing rows leave the KGConfigSettings
    defaults in place.

    Raises:
        ValueError: if a numeric variable holds a non-numeric value.
    """
    results = db_session.query(KGConfig).all()
    kg_config_settings = KGConfigSettings()

    for result in results:
        if result.kg_variable_name == KGConfigVars.KG_ENABLED:
            # booleans are stored as single-element lists: ["true"] / ["false"]
            kg_config_settings.KG_ENABLED = result.kg_variable_values[0] == "true"
        elif result.kg_variable_name == KGConfigVars.KG_VENDOR:
            if len(result.kg_variable_values) > 0:
                kg_config_settings.KG_VENDOR = result.kg_variable_values[0]
            else:
                kg_config_settings.KG_VENDOR = None
        elif result.kg_variable_name == KGConfigVars.KG_VENDOR_DOMAINS:
            kg_config_settings.KG_VENDOR_DOMAINS = result.kg_variable_values
        elif result.kg_variable_name == KGConfigVars.KG_IGNORE_EMAIL_DOMAINS:
            kg_config_settings.KG_IGNORE_EMAIL_DOMAINS = result.kg_variable_values
        elif result.kg_variable_name == KGConfigVars.KG_COVERAGE_START:
            # fall back to the epoch when the value is empty
            kg_coverage_start_str = result.kg_variable_values[0] or "1970-01-01"
            kg_config_settings.KG_COVERAGE_START = datetime.strptime(
                kg_coverage_start_str, "%Y-%m-%d"
            )
        elif result.kg_variable_name == KGConfigVars.KG_MAX_COVERAGE_DAYS:
            kg_max_coverage_days_str = result.kg_variable_values[0]
            # isdigit() already rejects negative values, so the max() below is
            # only a safety net
            if not kg_max_coverage_days_str.isdigit():
                raise ValueError(
                    f"KG_MAX_COVERAGE_DAYS is not a number: {kg_max_coverage_days_str}"
                )
            kg_config_settings.KG_MAX_COVERAGE_DAYS = max(
                0, int(kg_max_coverage_days_str)
            )
        elif result.kg_variable_name == KGConfigVars.KG_EXTRACTION_IN_PROGRESS:
            kg_config_settings.KG_EXTRACTION_IN_PROGRESS = (
                result.kg_variable_values[0] == "true"
            )
        elif result.kg_variable_name == KGConfigVars.KG_CLUSTERING_IN_PROGRESS:
            kg_config_settings.KG_CLUSTERING_IN_PROGRESS = (
                result.kg_variable_values[0] == "true"
            )
        elif result.kg_variable_name == KGConfigVars.KG_MAX_PARENT_RECURSION_DEPTH:
            kg_max_parent_recursion_depth_str = result.kg_variable_values[0]
            if not kg_max_parent_recursion_depth_str.isdigit():
                raise ValueError(
                    f"KG_MAX_PARENT_RECURSION_DEPTH is not a number: {kg_max_parent_recursion_depth_str}"
                )
            kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH = max(
                0, int(kg_max_parent_recursion_depth_str)
            )
        elif result.kg_variable_name == KGConfigVars.KG_EXPOSED:
            kg_config_settings.KG_EXPOSED = result.kg_variable_values[0] == "true"

    return kg_config_settings
def set_kg_processing_in_progress_status(
    db_session: Session, processing_type: KGProcessingType, in_progress: bool
) -> None:
    """
    Upsert the KG_EXTRACTION_IN_PROGRESS or KG_CLUSTERING_IN_PROGRESS flag.

    Args:
        db_session: The database session to use
        processing_type: Which processing phase's flag to set
        in_progress: Whether KG processing is in progress (True) or not (False)
    """
    if processing_type == KGProcessingType.CLUSTERING:
        kg_variable_name = "KG_CLUSTERING_IN_PROGRESS"
    else:
        kg_variable_name = "KG_EXTRACTION_IN_PROGRESS"

    # The model stores values as a list of strings, e.g. ["true"] / ["false"]
    value = [str(in_progress).lower()]

    # Use PostgreSQL's upsert functionality
    upsert_stmt = (
        pg_insert(KGConfig)
        .values(kg_variable_name=str(kg_variable_name), kg_variable_values=value)
        .on_conflict_do_update(
            index_elements=["kg_variable_name"], set_=dict(kg_variable_values=value)
        )
    )
    db_session.execute(upsert_stmt)
def get_kg_processing_in_progress_status(
    db_session: Session, processing_type: KGProcessingType
) -> bool:
    """
    Read the KG_EXTRACTION_IN_PROGRESS or KG_CLUSTERING_IN_PROGRESS flag.

    Args:
        db_session: The database session to use

    Returns:
        bool: True if the flag row exists and is set to "true", else False
    """
    if processing_type == KGProcessingType.CLUSTERING:
        kg_variable_name = "KG_CLUSTERING_IN_PROGRESS"
    else:
        kg_variable_name = "KG_EXTRACTION_IN_PROGRESS"

    config = (
        db_session.query(KGConfig)
        .filter(KGConfig.kg_variable_name == kg_variable_name)
        .first()
    )
    if config is None:
        return False
    return config.kg_variable_values[0] == "true"

View File

@@ -0,0 +1,167 @@
from sqlalchemy import text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_USER
Base = declarative_base()
def get_user_view_names(user_email: str) -> tuple[str, str]:
    """Derive the per-user names for the temporary KG access views.

    The email is reduced to a safe SQL identifier fragment: every character
    that is not an ASCII letter, digit, or underscore is replaced with "_".
    Emails may legally contain characters such as "+" or "-", which the
    previous "@"/"." -only replacement would have passed through, producing
    an invalid (and unsafe) identifier when interpolated into the
    CREATE VIEW / GRANT statements.

    Args:
        user_email: Email address of the user the views are created for

    Returns:
        Tuple of (allowed_docs view name, kg_relationships view name)
    """
    user_email_cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
        for ch in user_email
    )
    return (
        f"allowed_docs_{user_email_cleaned}",
        f"kg_relationships_with_access_{user_email_cleaned}",
    )
# Creates the per-user access views; drop_views below is its inverse.
def create_views(
    db_session: Session,
    user_email: str,
    allowed_docs_view_name: str = "allowed_docs",
    kg_relationships_view_name: str = "kg_relationships_with_access",
) -> None:
    """Create the per-user KG access views and grant the readonly user SELECT.

    Two views are created:
      * allowed_docs: ids of all KG-referenced documents the user may read
        (public docs, docs they own, docs shared via user groups, and docs
        granted through external user/group permissions)
      * kg_relationships_with_access: KG relationships restricted to source
        documents present in allowed_docs, joined with source/target entity
        attributes and the source document's update time

    NOTE: the view names are interpolated directly into DDL — callers must
    only pass identifier-safe names (see get_user_view_names). user_email is
    bound as a query parameter and is safe.

    Args:
        db_session: SQLAlchemy session
        user_email: Email of the user the views are scoped to
        allowed_docs_view_name: Name for the allowed-docs view
        kg_relationships_view_name: Name for the relationships view
    """
    # Create ALLOWED_DOCS view
    allowed_docs_view = text(
        f"""
        CREATE OR REPLACE VIEW {allowed_docs_view_name} AS
        WITH kg_used_docs AS (
            SELECT document_id as kg_used_doc_id
            FROM kg_entity d
            WHERE document_id IS NOT NULL
        ),
        public_docs AS (
            SELECT d.id as allowed_doc_id
            FROM document d
            INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
            WHERE d.is_public
        ),
        user_owned_docs AS (
            SELECT d.id as allowed_doc_id
            FROM document_by_connector_credential_pair d
            JOIN credential c ON d.credential_id = c.id
            JOIN connector_credential_pair ccp ON
                d.connector_id = ccp.connector_id AND
                d.credential_id = ccp.credential_id
            JOIN "user" u ON c.user_id = u.id
            INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
            WHERE ccp.status != 'DELETING'
            AND ccp.access_type != 'SYNC'
            AND u.email = :user_email
        ),
        user_group_accessible_docs AS (
            SELECT d.id as allowed_doc_id
            FROM document_by_connector_credential_pair d
            JOIN connector_credential_pair ccp ON
                d.connector_id = ccp.connector_id AND
                d.credential_id = ccp.credential_id
            JOIN user_group__connector_credential_pair ugccp ON
                ccp.id = ugccp.cc_pair_id
            JOIN user__user_group uug ON
                uug.user_group_id = ugccp.user_group_id
            JOIN "user" u ON uug.user_id = u.id
            INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
            WHERE kud.kg_used_doc_id IS NOT NULL
            AND ccp.status != 'DELETING'
            AND ccp.access_type != 'SYNC'
            AND u.email = :user_email
        ),
        external_user_docs AS (
            SELECT d.id as allowed_doc_id
            FROM document d
            INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
            WHERE kud.kg_used_doc_id IS NOT NULL
            AND :user_email = ANY(external_user_emails)
        ),
        external_group_docs AS (
            SELECT d.id as allowed_doc_id
            FROM document d
            INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id
            JOIN user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
            JOIN "user" u ON ueg.user_id = u.id
            WHERE kud.kg_used_doc_id IS NOT NULL
            AND u.email = :user_email
        )
        SELECT DISTINCT allowed_doc_id FROM (
            SELECT allowed_doc_id FROM public_docs
            UNION
            SELECT allowed_doc_id FROM user_owned_docs
            UNION
            SELECT allowed_doc_id FROM user_group_accessible_docs
            UNION
            SELECT allowed_doc_id FROM external_user_docs
            UNION
            SELECT allowed_doc_id FROM external_group_docs
        ) combined_docs
        """
    ).bindparams(user_email=user_email)

    # Create the main view that uses ALLOWED_DOCS
    kg_relationships_view = text(
        f"""
        CREATE OR REPLACE VIEW {kg_relationships_view_name} AS
        SELECT kgr.id_name as relationship,
            kgr.source_node as source_entity,
            kgr.target_node as target_entity,
            kgr.source_node_type as source_entity_type,
            kgr.target_node_type as target_entity_type,
            kgr.type as relationship_description,
            kgr.relationship_type_id_name as relationship_type,
            kgr.source_document as source_document,
            d.doc_updated_at as source_date,
            se.attributes as source_entity_attributes,
            te.attributes as target_entity_attributes
        FROM kg_relationship kgr
        INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document
        JOIN document d on d.id = kgr.source_document
        JOIN kg_entity se on se.id_name = kgr.source_node
        JOIN kg_entity te on te.id_name = kgr.target_node
        """
    )

    # Execute the views using the session
    db_session.execute(allowed_docs_view)
    db_session.execute(kg_relationships_view)

    # Grant permissions on view to readonly user
    db_session.execute(
        text(f"GRANT SELECT ON {kg_relationships_view_name} TO {DB_READONLY_USER}")
    )

    db_session.commit()

    return None
def drop_views(
    db_session: Session,
    allowed_docs_view_name: str = "allowed_docs",
    kg_relationships_view_name: str = "kg_relationships_with_access",
) -> None:
    """Tear down the temporary views created by create_views.

    The readonly user's SELECT grant is revoked first so the grant never
    outlives the view, then the views are dropped in reverse order of
    creation (the relationships view depends on the allowed-docs view).

    Args:
        db_session: SQLAlchemy session
        allowed_docs_view_name: Name of the allowed_docs view
        kg_relationships_view_name: Name of the kg_relationships view
    """
    # Revoke read access from the readonly user before dropping the view.
    db_session.execute(
        text(f"REVOKE SELECT ON {kg_relationships_view_name} FROM {DB_READONLY_USER}")
    )
    # Drop the dependent view first, then the view it was built on.
    for drop_statement in (
        f"DROP VIEW IF EXISTS {kg_relationships_view_name}",
        f"DROP VIEW IF EXISTS {allowed_docs_view_name}",
    ):
        db_session.execute(text(drop_statement))
    db_session.commit()
    return None

View File

@@ -39,6 +39,7 @@ from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy import PrimaryKeyConstraint
from onyx.auth.schemas import UserRole
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
@@ -69,6 +70,7 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.context.search.enums import RecencyBiasSetting
from onyx.kg.models import KGStage
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.headers import HeaderItemDict
@@ -586,6 +588,17 @@ class Document(Base):
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
# tables for the knowledge graph data
kg_stage: Mapped[KGStage] = mapped_column(
Enum(KGStage, native_enum=False),
comment="Status of knowledge graph extraction for this document",
index=True,
)
kg_processing_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
@@ -604,12 +617,639 @@ class Document(Base):
)
class KGConfig(Base):
    """Key/value store for knowledge-graph configuration variables.

    Each row maps one variable name to a list of string values.
    """

    __tablename__ = "kg_config"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Name of the configuration variable (unique -- see constraint below).
    kg_variable_name: Mapped[str] = mapped_column(NullFilteredString, nullable=False)
    # Values for the variable, stored as a Postgres text array.
    kg_variable_values: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    __table_args__ = (
        UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"),
    )
class KGEntityType(Base):
    """Definition of a knowledge-graph entity type.

    Referenced by both the normalized KG tables (KGEntity, KGRelationshipType)
    and the extraction staging tables.
    """

    __tablename__ = "kg_entity_type"
    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        String, primary_key=True, nullable=False, index=True
    )
    description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
    # NOTE(review): grounding semantics are defined by the KG extraction
    # pipeline and are not visible from this model.
    grounding: Mapped[str] = mapped_column(
        NullFilteredString, nullable=False, index=False
    )
    # JSONB payload (annotation widened from str to match the column type).
    attributes: Mapped[dict | None] = mapped_column(
        postgresql.JSONB,
        nullable=True,
        default=dict,
        server_default="{}",
        comment="Filtering based on document attribute",
    )
    # Number of times entities of this type have been observed.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # Entity types are inactive until explicitly enabled.
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # Whether deeper extraction runs for this type -- NOTE(review): exact
    # semantics live in the KG pipeline; confirm there.
    deep_extraction: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )
    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Name of the grounded source backing this type.
    grounded_source_name: Mapped[str] = mapped_column(
        NullFilteredString, nullable=False, index=False
    )
    # Optional fixed vocabulary of values for this type (NULL when unrestricted).
    entity_values: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True, default=None
    )
    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this entity type",
    )
class KGRelationshipType(Base):
    """A normalized relationship type between two entity types.

    id_name encodes the source entity type, the relationship name, and the
    target entity type; rows arrive here from the staging table.
    """

    __tablename__ = "kg_relationship_type"
    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        nullable=False,
        index=True,
    )
    # Human-readable relationship name (e.g. the verb phrase).
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    source_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    definition: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this relationship type represents a definition",
    )
    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this relationship type",
    )
    # NOTE(review): `type` appears to mirror `name` at insert time (see
    # upsert_staging_relationship_type) -- confirm intended distinction.
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Number of times this relationship type has been observed.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Relationships to EntityType
    source_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[source_entity_type_id_name],
        backref="source_relationship_type",
    )
    target_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[target_entity_type_id_name],
        backref="target_relationship_type",
    )
class KGRelationshipTypeExtractionStaging(Base):
    """Staging counterpart of KGRelationshipType, populated during extraction.

    Rows are later copied into KGRelationshipType and marked via the
    `transferred` flag.
    """

    __tablename__ = "kg_relationship_type_extraction_staging"
    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        nullable=False,
        index=True,
    )
    # Human-readable relationship name.
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    source_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    definition: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this relationship type represents a definition",
    )
    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this relationship type",
    )
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Number of times this relationship type has been extracted.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # True once this row has been copied to the normalized table.
    transferred: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )
    # Tracking fields
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Relationships to EntityType
    source_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[source_entity_type_id_name],
        backref="source_relationship_type_staging",
    )
    target_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[target_entity_type_id_name],
        backref="target_relationship_type_staging",
    )
class KGEntity(Base):
    """A normalized knowledge-graph entity.

    Rows arrive here from KGEntityExtractionStaging after clustering /
    transfer.
    """

    __tablename__ = "kg_entity"
    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, index=True
    )
    # Basic entity information
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    entity_class: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    entity_key: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    entity_subtype: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    # Character trigrams of `name` -- presumably for fuzzy matching; confirm
    # against the clustering code.
    name_trigrams: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String(3)),
        nullable=True,
    )
    attributes: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Attributes for this entity",
    )
    # Id of the document this entity was extracted from (if any).
    document_id: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    alternative_names: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Reference to KGEntityType
    entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    # Relationship to KGEntityType
    entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
    description: Mapped[str | None] = mapped_column(String, nullable=True)
    keywords: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Number of times this entity has been observed.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # Access control
    acl: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Boosts - using JSON for flexibility
    boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
    event_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="Time of the event being processed",
    )
    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    __table_args__ = (
        # Fixed column names in indexes
        Index("ix_entity_type_acl", entity_type_id_name, acl),
        Index("ix_entity_name_search", name, entity_type_id_name),
    )
class KGEntityExtractionStaging(Base):
    """Staging counterpart of KGEntity, populated during KG extraction.

    Rows are clustered and copied into KGEntity; `transferred_id_name`
    records the id_name the entity received in the normalized table.
    """

    __tablename__ = "kg_entity_extraction_staging"
    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        nullable=False,
        index=True,
    )
    # Basic entity information
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    attributes: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Attributes for this entity",
    )
    # Id of the document this entity was extracted from (if any).
    document_id: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    alternative_names: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Reference to KGEntityType
    entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    # Relationship to KGEntityType
    entity_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType", backref="entity_staging"
    )
    description: Mapped[str | None] = mapped_column(String, nullable=True)
    keywords: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Number of times this entity has been extracted.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # Access control
    acl: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Boosts - using JSON for flexibility
    boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
    # id_name of the normalized KGEntity this row was transferred to
    # (NULL until the transfer has happened).
    transferred_id_name: Mapped[str | None] = mapped_column(
        NullFilteredString,
        nullable=True,
    )
    entity_class: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    entity_key: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    entity_subtype: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    # Key of the parent entity (same entity_class); used to derive
    # has_subcomponent relationships after transfer.
    parent_key: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    event_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="Time of the event being processed",
    )
    # Tracking fields
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    __table_args__ = (
        # Named distinctly from KGEntity's indexes: Postgres index names are
        # unique per schema, so reusing "ix_entity_type_acl" /
        # "ix_entity_name_search" here would fail at CREATE INDEX time.
        Index("ix_entity_staging_type_acl", entity_type_id_name, acl),
        Index("ix_entity_staging_name_search", name, entity_type_id_name),
    )
class KGRelationship(Base):
    """A normalized relationship (edge) between two KGEntity nodes.

    Keyed by (id_name, source_document) so the same relationship can be
    recorded once per source document.
    """

    __tablename__ = "kg_relationship"
    # Primary identifier - now part of composite key
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        nullable=False,
        index=True,
    )
    # NOTE(review): nullable=True on a primary-key member is unusual --
    # Postgres forces PK columns to NOT NULL; confirm which is intended.
    source_document: Mapped[str | None] = mapped_column(
        NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
    )
    # Source and target nodes (foreign keys to Entity table)
    source_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )
    target_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )
    source_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    # Relationship type
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    # Add new relationship type reference
    relationship_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_relationship_type.id_name"),
        nullable=False,
        index=True,
    )
    # Add the SQLAlchemy relationship property
    relationship_type: Mapped["KGRelationshipType"] = relationship(
        "KGRelationshipType", backref="relationship"
    )
    # Number of times this relationship has been observed for the document.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Relationships to Entity table
    source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
    target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
    document: Mapped["Document"] = relationship(
        "Document", foreign_keys=[source_document]
    )
    __table_args__ = (
        # Composite primary key
        PrimaryKeyConstraint("id_name", "source_document"),
        # Index for querying relationships by type
        Index("ix_kg_relationship_type", type),
        # Composite index for source/target queries
        Index("ix_kg_relationship_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )
class KGRelationshipExtractionStaging(Base):
    """Staging counterpart of KGRelationship, populated during extraction.

    Node FKs point at the staging entity table; rows are copied into
    KGRelationship by the transfer step and flagged via `transferred`.
    """

    __tablename__ = "kg_relationship_extraction_staging"
    # Primary identifier - part of the composite key below
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        nullable=False,
        index=True,
    )
    # NOTE(review): nullable=True on a primary-key member is unusual --
    # Postgres forces PK columns to NOT NULL; confirm which is intended.
    source_document: Mapped[str | None] = mapped_column(
        NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
    )
    # Source and target nodes (foreign keys to the staging entity table)
    source_node: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )
    target_node: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )
    source_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    # Relationship type
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    # Reference to the staging relationship-type row
    relationship_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_relationship_type_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )
    relationship_type: Mapped["KGRelationshipTypeExtractionStaging"] = relationship(
        "KGRelationshipTypeExtractionStaging", backref="relationship_staging"
    )
    # Number of times this relationship has been extracted for the document.
    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    # True once this row has been copied to the normalized table.
    transferred: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )
    # Tracking fields
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Relationships to the staging entity table
    source: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[source_node]
    )
    target: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[target_node]
    )
    document: Mapped["Document"] = relationship(
        "Document", foreign_keys=[source_document]
    )
    __table_args__ = (
        # Composite primary key
        PrimaryKeyConstraint("id_name", "source_document"),
        # Index/constraint names are made distinct from KGRelationship's:
        # Postgres index and constraint names are unique per schema, so reusing
        # "ix_kg_relationship_type" etc. here would fail at creation time.
        Index("ix_kg_relationship_staging_type", type),
        Index("ix_kg_relationship_staging_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_staging_source_target_type",
        ),
    )
class KGTerm(Base):
    """A search term mapped to the entity types it applies to."""

    __tablename__ = "kg_term"
    # Make id_term the primary key
    id_term: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, nullable=False, index=True
    )
    # List of entity types this term applies to
    entity_types: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    __table_args__ = (
        # Index for searching terms with specific entity types
        # NOTE(review): a plain btree on an ARRAY column does not accelerate
        # containment queries -- a GIN index may be intended; confirm usage.
        Index("ix_search_term_entities", entity_types),
        # Index for term lookups
        # NOTE(review): id_term is the primary key, so this index duplicates
        # the implicit PK index.
        Index("ix_search_term_term", id_term),
    )
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
    # (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
@@ -691,6 +1331,16 @@ class Connector(Base):
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
kg_processing_enabled: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this connector should extract knowledge graph entities",
)
kg_coverage_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(

View File

@@ -0,0 +1,590 @@
from typing import List
from sqlalchemy import or_
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
import onyx.db.document as dbdocument
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipType
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.db.models import KGStage
from onyx.kg.utils.formatting_utils import extract_relationship_type_id
from onyx.kg.utils.formatting_utils import format_relationship_id
from onyx.kg.utils.formatting_utils import get_entity_type
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import make_relationship_type_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def upsert_staging_relationship(
    db_session: Session,
    relationship_id_name: str,
    source_document_id: str,
    occurrences: int = 1,
) -> KGRelationshipExtractionStaging:
    """Insert a staging relationship, or bump its occurrence count.

    Args:
        db_session: SQLAlchemy database session
        relationship_id_name: The ID name of the relationship in format
            "source__relationship__target"
        source_document_id: ID of the source document
        occurrences: Number of times this relationship has been found

    Returns:
        The created or updated KGRelationshipExtractionStaging object

    Raises:
        RuntimeError: If the upsert returns no row
        sqlalchemy.exc.IntegrityError: If there's an error with the database operation
    """
    normalized_id = format_relationship_id(relationship_id_name)
    source_id, relationship_string, target_id = split_relationship_id(normalized_id)

    row_values = {
        "id_name": normalized_id,
        "source_node": source_id,
        "target_node": target_id,
        "source_node_type": get_entity_type(source_id),
        "target_node_type": get_entity_type(target_id),
        "type": relationship_string.lower(),
        "relationship_type_id_name": extract_relationship_type_id(normalized_id),
        "source_document": source_document_id,
        "occurrences": occurrences,
    }
    # Upsert keyed on (id_name, source_document); conflicts just add to the
    # existing occurrence count.
    upsert_stmt = (
        postgresql.insert(KGRelationshipExtractionStaging)
        .values(row_values)
        .on_conflict_do_update(
            index_elements=["id_name", "source_document"],
            set_={
                "occurrences": KGRelationshipExtractionStaging.occurrences
                + occurrences
            },
        )
        .returning(KGRelationshipExtractionStaging)
    )
    upserted_row = db_session.execute(upsert_stmt).scalar()
    if upserted_row is None:
        raise RuntimeError(
            f"Failed to create or increment staging relationship with id_name: {normalized_id}"
        )

    # Record on the source document that KG extraction has produced results.
    if source_document_id is not None:
        dbdocument.update_document_kg_info(
            db_session,
            document_id=source_document_id,
            kg_stage=KGStage.EXTRACTED,
        )
    db_session.flush()  # Flush to get any DB errors early
    return upserted_row
def transfer_relationship(
    db_session: Session,
    relationship: KGRelationshipExtractionStaging,
    entity_translations: dict[str, str],
) -> KGRelationship:
    """Move a staging relationship into the normalized kg_relationship table.

    The source and target node ids are mapped through entity_translations
    (falling back to the staging ids), the row is upserted into
    KGRelationship, and the staging row is flagged as transferred.
    """
    # Map staging entity ids to their normalized counterparts.
    translated_source = entity_translations.get(
        relationship.source_node, relationship.source_node
    )
    translated_target = entity_translations.get(
        relationship.target_node, relationship.target_node
    )
    normalized_id_name = make_relationship_id(
        translated_source, relationship.type, translated_target
    )

    upsert_stmt = (
        pg_insert(KGRelationship)
        .values(
            id_name=normalized_id_name,
            source_node=translated_source,
            target_node=translated_target,
            source_node_type=relationship.source_node_type,
            target_node_type=relationship.target_node_type,
            type=relationship.type,
            relationship_type_id_name=relationship.relationship_type_id_name,
            source_document=relationship.source_document,
            occurrences=relationship.occurrences,
        )
        .on_conflict_do_update(
            index_elements=["id_name", "source_document"],
            set_={
                "occurrences": KGRelationship.occurrences + relationship.occurrences
            },
        )
        .returning(KGRelationship)
    )
    transferred_row = db_session.execute(upsert_stmt).scalar()
    if transferred_row is None:
        raise RuntimeError(
            f"Failed to transfer relationship with id_name: {relationship.id_name}"
        )

    # Flag the staging row so it is not transferred again.
    db_session.query(KGRelationshipExtractionStaging).filter(
        KGRelationshipExtractionStaging.id_name == relationship.id_name,
        KGRelationshipExtractionStaging.source_document == relationship.source_document,
    ).update({"transferred": True})
    db_session.flush()
    return transferred_row
def upsert_staging_relationship_type(
    db_session: Session,
    source_entity_type: str,
    relationship_type: str,
    target_entity_type: str,
    definition: bool = False,
    extraction_count: int = 1,
) -> KGRelationshipTypeExtractionStaging:
    """Insert a staging relationship type, or bump its occurrence count.

    Args:
        db_session: SQLAlchemy session
        source_entity_type: Type of the source entity
        relationship_type: Type of relationship
        target_entity_type: Type of the target entity
        definition: Whether this relationship type represents a definition
            (default False)
        extraction_count: Amount to add to the occurrence count

    Returns:
        The created or updated KGRelationshipTypeExtractionStaging object
    """
    id_name = make_relationship_type_id(
        source_entity_type, relationship_type, target_entity_type
    )

    row_values = {
        "id_name": id_name,
        "name": relationship_type,
        "source_entity_type_id_name": source_entity_type.upper(),
        "target_entity_type_id_name": target_entity_type.upper(),
        "definition": definition,
        "occurrences": extraction_count,
        "type": relationship_type,  # Using the relationship_type as the type
        "active": True,  # Setting as active by default
    }
    # Upsert keyed on id_name; conflicts just add to the occurrence count.
    upsert_stmt = (
        postgresql.insert(KGRelationshipTypeExtractionStaging)
        .values(row_values)
        .on_conflict_do_update(
            index_elements=["id_name"],
            set_={
                "occurrences": KGRelationshipTypeExtractionStaging.occurrences
                + extraction_count
            },
        )
        .returning(KGRelationshipTypeExtractionStaging)
    )
    upserted_row = db_session.execute(upsert_stmt).scalar()
    if upserted_row is None:
        raise RuntimeError(
            f"Failed to create or increment staging relationship type with id_name: {id_name}"
        )
    db_session.flush()  # Flush to get any DB errors early
    return upserted_row
def transfer_relationship_type(
    db_session: Session,
    relationship_type: KGRelationshipTypeExtractionStaging,
) -> KGRelationshipType:
    """Move a staging relationship type into the normalized table.

    Upserts the row into KGRelationshipType (adding occurrence counts on
    conflict) and flags the staging row as transferred.
    """
    upsert_stmt = (
        pg_insert(KGRelationshipType)
        .values(
            id_name=relationship_type.id_name,
            name=relationship_type.name,
            source_entity_type_id_name=relationship_type.source_entity_type_id_name,
            target_entity_type_id_name=relationship_type.target_entity_type_id_name,
            definition=relationship_type.definition,
            occurrences=relationship_type.occurrences,
            type=relationship_type.type,
            active=relationship_type.active,
        )
        .on_conflict_do_update(
            index_elements=["id_name"],
            set_={
                "occurrences": KGRelationshipType.occurrences
                + relationship_type.occurrences
            },
        )
        .returning(KGRelationshipType)
    )
    transferred_row = db_session.execute(upsert_stmt).scalar()
    if transferred_row is None:
        raise RuntimeError(
            f"Failed to transfer relationship type with id_name: {relationship_type.id_name}"
        )

    # Flag the staging row so it is not transferred again.
    db_session.query(KGRelationshipTypeExtractionStaging).filter(
        KGRelationshipTypeExtractionStaging.id_name == relationship_type.id_name
    ).update({"transferred": True})
    db_session.flush()
    return transferred_row
def get_parent_child_relationships_and_types(
    db_session: Session,
    depth: int,
) -> tuple[
    list[KGRelationshipExtractionStaging], list[KGRelationshipTypeExtractionStaging]
]:
    """
    Create parent-child relationships and relationship types from staging entities with
    a parent key, if the parent exists in the normalized entities table. Will create
    relationships up to depth levels. E.g., if depth is 2, a relationship will be created
    between the entity and its parent, and the entity and its grandparents (if any).
    A relationship will not be created if the parent does not exist.

    Args:
        db_session: SQLAlchemy database session
        depth: Maximum number of ancestor levels to link each entity to

    Returns:
        Tuple of (relationships, relationship_types). The relationship objects are
        NOT added to the session (they reference transferred entity ids, which would
        break the staging table's foreign keys); the relationship types ARE upserted
        into the staging table.
    """
    relationship_types: dict[str, KGRelationshipTypeExtractionStaging] = {}
    relationships: dict[tuple[str, str | None], KGRelationshipExtractionStaging] = {}
    parented_entities = (
        db_session.query(KGEntityExtractionStaging)
        .filter(KGEntityExtractionStaging.parent_key.isnot(None))
        .all()
    )
    # create has_subcomponent relationships and relationship types
    for entity in parented_entities:
        child = entity
        if entity.transferred_id_name is None:
            logger.warning(f"Entity {entity.id_name} has not yet been transferred")
            continue
        for i in range(depth):
            if not child.parent_key:
                break
            # find the transferred parent entity
            parent = (
                db_session.query(KGEntity)
                .filter(
                    KGEntity.entity_class == child.entity_class,
                    KGEntity.entity_key == child.parent_key,
                )
                .first()
            )
            if parent is None:
                logger.warning(f"Parent entity not found for {entity.id_name}")
                break
            # create the relationship type
            relationship_type = upsert_staging_relationship_type(
                db_session=db_session,
                source_entity_type=parent.entity_type_id_name,
                relationship_type="has_subcomponent",
                target_entity_type=entity.entity_type_id_name,
                definition=False,
                extraction_count=1,
            )
            relationship_types[relationship_type.id_name] = relationship_type
            # create the relationship
            # (don't add it to the table as we're using the transferred id, which breaks fk constraints)
            relationship_id_name = make_relationship_id(
                parent.id_name, "has_subcomponent", entity.transferred_id_name
            )
            # BUGFIX: key the dedup dict consistently on the relationship id.
            # Previously the membership check used parent.id_name while entries
            # were stored under relationship_id_name, so duplicates overwrote
            # each other and the increment branch would raise KeyError.
            relationship_key = (relationship_id_name, entity.document_id)
            if relationship_key not in relationships:
                (
                    source_entity_id_name,
                    relationship_string,
                    target_entity_id_name,
                ) = split_relationship_id(relationship_id_name)
                relationships[relationship_key] = KGRelationshipExtractionStaging(
                    id_name=relationship_id_name,
                    source_node=source_entity_id_name,
                    target_node=target_entity_id_name,
                    source_node_type=get_entity_type(source_entity_id_name),
                    target_node_type=get_entity_type(target_entity_id_name),
                    type=relationship_string,
                    relationship_type_id_name=extract_relationship_type_id(
                        relationship_id_name
                    ),
                    source_document=entity.document_id,
                    occurrences=1,
                )
            else:
                relationships[relationship_key].occurrences += 1
            # set parent as the next child (unless we're at the max depth)
            if i < depth - 1:
                parent_staging = (
                    db_session.query(KGEntityExtractionStaging)
                    .filter(
                        KGEntityExtractionStaging.transferred_id_name == parent.id_name
                    )
                    .first()
                )
                if parent_staging is None:
                    break
                child = parent_staging
    return list(relationships.values()), list(relationship_types.values())
def delete_relationships_by_id_names(
    db_session: Session, id_names: list[str], kg_stage: KGStage
) -> int:
    """
    Delete relationships from the database based on a list of id_names.

    Args:
        db_session: SQLAlchemy database session
        id_names: List of relationship id_names to delete
        kg_stage: Which table to delete from (EXTRACTED -> staging,
            NORMALIZED -> normalized); other stages delete nothing

    Returns:
        Number of relationships deleted

    Raises:
        sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
    """
    # Pick the table that corresponds to the requested KG stage.
    stage_to_model = {
        KGStage.EXTRACTED: KGRelationshipExtractionStaging,
        KGStage.NORMALIZED: KGRelationship,
    }
    model = stage_to_model.get(kg_stage)
    deleted_count = 0
    if model is not None:
        deleted_count = (
            db_session.query(model)
            .filter(model.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    db_session.flush()  # Flush to ensure deletion is processed
    return deleted_count
def delete_relationship_types_by_id_names(
    db_session: Session, id_names: list[str], kg_stage: KGStage
) -> int:
    """
    Delete relationship types from the database based on a list of id_names.

    Args:
        db_session: SQLAlchemy database session
        id_names: List of relationship type id_names to delete
        kg_stage: Which table to delete from (EXTRACTED -> staging,
            NORMALIZED -> normalized); other stages delete nothing

    Returns:
        Number of relationship types deleted

    Raises:
        sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
    """
    # Pick the table that corresponds to the requested KG stage.
    stage_to_model = {
        KGStage.EXTRACTED: KGRelationshipTypeExtractionStaging,
        KGStage.NORMALIZED: KGRelationshipType,
    }
    model = stage_to_model.get(kg_stage)
    deleted_count = 0
    if model is not None:
        deleted_count = (
            db_session.query(model)
            .filter(model.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    db_session.flush()  # Flush to ensure deletion is processed
    return deleted_count
def get_relationships_for_entity_type_pairs(
    db_session: Session, entity_type_pairs: list[tuple[str, str]]
) -> list["KGRelationshipType"]:
    """
    Get relationship types from the database based on a list of entity type pairs.

    Args:
        db_session: SQLAlchemy database session
        entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)

    Returns:
        List of KGRelationshipType objects where source and target types match the provided pairs
    """
    # or_() with zero clauses is deprecated/ill-defined in SQLAlchemy, so
    # short-circuit on empty input instead of building an empty disjunction.
    if not entity_type_pairs:
        return []

    conditions = [
        (
            (KGRelationshipType.source_entity_type_id_name == source_type)
            & (KGRelationshipType.target_entity_type_id_name == target_type)
        )
        for source_type, target_type in entity_type_pairs
    ]
    return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
def get_allowed_relationship_type_pairs(
    db_session: Session, entities: list[str]
) -> list[str]:
    """
    Get the allowed relationship pairs for the given entities.

    Args:
        db_session: SQLAlchemy database session
        entities: List of entity type ID names to filter by

    Returns:
        List of id_names from KGRelationshipType where both source and target
        entity types are in the provided entities list
    """
    # distinct entity types present in the input entities
    entity_types = list({get_entity_type(entity) for entity in entities})

    query = (
        db_session.query(KGRelationshipType.id_name)
        .filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
        .filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
        .distinct()
    )
    return [id_name for (id_name,) in query.all()]
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
    """Get all relationship ID names where the given entity is either the source or target node.

    Args:
        db_session: SQLAlchemy session
        entity_id: ID of the entity to find relationships for

    Returns:
        List of relationship ID names where the entity is either source or target
    """
    involves_entity = or_(
        KGRelationship.source_node == entity_id,
        KGRelationship.target_node == entity_id,
    )
    rows = db_session.query(KGRelationship.id_name).filter(involves_entity).all()
    return [id_name for (id_name,) in rows]
def get_relationship_types_of_entity_types(
    db_session: Session, entity_types_id: str
) -> List[str]:
    """Get all relationship type ID names where the given entity type is the
    source or target entity type.

    Args:
        db_session: SQLAlchemy session
        entity_types_id: entity type id, optionally with a ":*" wildcard suffix

    Returns:
        List of relationship type ID names involving the entity type
    """
    # tolerate the "TYPE:*" wildcard form by stripping the suffix
    lookup_id = entity_types_id.removesuffix(":*")

    involves_type = or_(
        KGRelationshipType.source_entity_type_id_name == lookup_id,
        KGRelationshipType.target_entity_type_id_name == lookup_id,
    )
    rows = db_session.query(KGRelationshipType.id_name).filter(involves_type).all()
    return [id_name for (id_name,) in rows]
def delete_document_references_from_kg(db_session: Session, document_id: str) -> None:
    """Remove every KG row that references `document_id`.

    Deletes relationships and entities tied to the document from both the
    normalized tables and the extraction staging tables, then flushes so the
    deletions are visible within the current transaction (no commit here).
    """
    # Delete relationships from normalized stage
    db_session.query(KGRelationship).filter(
        KGRelationship.source_document == document_id
    ).delete(synchronize_session=False)

    # Delete relationships from extraction staging
    db_session.query(KGRelationshipExtractionStaging).filter(
        KGRelationshipExtractionStaging.source_document == document_id
    ).delete(synchronize_session=False)

    # Delete entities from normalized stage
    db_session.query(KGEntity).filter(KGEntity.document_id == document_id).delete(
        synchronize_session=False
    )

    # Delete entities from extraction staging
    db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.document_id == document_id
    ).delete(synchronize_session=False)

    db_session.flush()
def delete_from_kg_relationships_extraction_staging__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete relationships from the extraction staging table.

    Removes staged relationships whose source document is in `document_ids`;
    the caller is responsible for committing.
    """
    stale = db_session.query(KGRelationshipExtractionStaging).filter(
        KGRelationshipExtractionStaging.source_document.in_(document_ids)
    )
    stale.delete(synchronize_session=False)
def delete_from_kg_relationships__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete relationships from the normalized table.

    Removes relationships whose source document is in `document_ids`;
    the caller is responsible for committing.
    """
    stale = db_session.query(KGRelationship).filter(
        KGRelationship.source_document.in_(document_ids)
    )
    stale.delete(synchronize_session=False)

View File

@@ -91,6 +91,25 @@ schema {{ schema_name }} {
indexing: attribute
}
# Separate array fields for knowledge graph data
field kg_entities type weightedset<string> {
rank: filter
indexing: summary | attribute
attribute: fast-search
}
field kg_relationships type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field kg_terms type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute

View File

@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
# build the list of fields to retrieve
field_set_list = (
None
if not field_names
else [f"{index_name}:{field_name}" for field_name in field_names]
[f"{field_name}" for field_name in field_names] if field_names else []
)
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
if (
field_set_list
and filters.access_control_list
and acl_fieldset_entry not in field_set_list
):
field_set_list.append(acl_fieldset_entry)
field_set = ",".join(field_set_list) if field_set_list else None
if field_set_list:
field_set = f"{index_name}:" + ",".join(field_set_list)
else:
field_set = None
# build filters
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"

View File

@@ -18,6 +18,7 @@ from uuid import UUID
import httpx # type: ignore
import jinja2
import requests # type: ignore
from pydantic import BaseModel
from retry import retry
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
@@ -30,6 +31,7 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
@@ -86,6 +88,17 @@ httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
def update_kg_type_dict(
    dict_to_update: dict[str, dict], kg_type: str, value_set: set[str]
) -> dict[str, dict]:
    """Store `value_set` under `fields[kg_type]` as a Vespa weightedset
    "assign" payload (every member weighted 1) and return the mutated dict."""
    fields = dict_to_update.setdefault("fields", {})
    fields[kg_type] = {"assign": {member: 1 for member in value_set}}
    return dict_to_update
@dataclass
class _VespaUpdateRequest:
document_id: str
@@ -93,6 +106,39 @@ class _VespaUpdateRequest:
update_request: dict[str, dict]
class KGVespaChunkUpdateRequest(BaseModel):
    """A fully-resolved Vespa partial-update call for a single chunk."""

    document_id: str
    chunk_id: int
    # Vespa document endpoint URL for this specific chunk.
    url: str
    # JSON body for the PUT request (Vespa "fields" assign payload).
    update_request: dict[str, dict]
class KGUChunkUpdateRequest(BaseModel):
    """
    Update KG fields for a document.

    Holds the KG field values requested for one chunk; fields left as None
    are skipped when the Vespa update payload is built.
    """

    document_id: str
    chunk_id: int
    # NOTE(review): some callers pass the literal "unused" here — confirm
    # whether this field is still consumed anywhere.
    core_entity: str
    entities: set[str] | None = None
    relationships: set[str] | None = None
    terms: set[str] | None = None
    converted_attributes: set[str] | None = None
    attributes: dict[str, str | list[str]] | None = None
class KGUDocumentUpdateRequest(BaseModel):
    """
    Update KG fields for a document.

    Document-level variant: entity/relationship/term sets to apply for an
    entire document (all fields required, unlike the per-chunk request).
    """

    document_id: str
    entities: set[str]
    relationships: set[str]
    terms: set[str]
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -501,6 +547,51 @@ class VespaIndex(DocumentIndex):
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
@classmethod
def _apply_kg_chunk_updates_batched(
    cls,
    updates: list[KGVespaChunkUpdateRequest],
    httpx_client: httpx.Client,
    batch_size: int = BATCH_SIZE,
) -> None:
    """Runs a batch of updates in parallel via the ThreadPoolExecutor.

    Args:
        updates: fully-resolved per-chunk update requests (URL + JSON body)
        httpx_client: client used for all PUTs; consumed as a context manager
        batch_size: number of updates submitted to the pool at a time

    Raises:
        requests.HTTPError: if any individual chunk update fails
    """

    def _kg_update_chunk(
        update: KGVespaChunkUpdateRequest, http_client: httpx.Client
    ) -> httpx.Response:
        return http_client.put(
            update.url,
            headers={"Content-Type": "application/json"},
            json=update.update_request,
        )

    # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
    # indexing / updates / deletes since we have to make a large volume of requests.
    with (
        concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
        httpx_client as http_client,
    ):
        for update_batch in batch_generator(updates, batch_size):
            future_to_document_id = {
                executor.submit(
                    _kg_update_chunk,
                    update,
                    http_client,
                ): update.document_id
                for update in update_batch
            }
            for future in concurrent.futures.as_completed(future_to_document_id):
                res = future.result()
                try:
                    res.raise_for_status()
                except httpx.HTTPError as e:
                    # httpx.Response.raise_for_status() raises
                    # httpx.HTTPStatusError (an httpx.HTTPError), NOT
                    # requests.HTTPError — catching the latter let failures
                    # propagate unwrapped. Re-raise as the requests.HTTPError
                    # callers already expect.
                    failure_msg = f"Failed to update document: {future_to_document_id[future]}"
                    raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
@@ -584,6 +675,63 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def kg_chunk_updates(
    self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
) -> None:
    """Apply KG entity/relationship/term partial updates to Vespa chunks.

    For each request, builds a weightedset "assign" payload for whichever KG
    fields are set (None fields are skipped), resolves the chunk's Vespa doc
    UUID, and PUTs all updates in parallel batches.
    """
    processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
    update_start = time.monotonic()

    # Build the _VespaUpdateRequest objects
    for kg_update_request in kg_update_requests:
        kg_update_dict: dict[str, dict] = {"fields": {}}

        if kg_update_request.relationships is not None:
            kg_update_dict = update_kg_type_dict(
                kg_update_dict, "kg_relationships", kg_update_request.relationships
            )
        if kg_update_request.entities is not None:
            kg_update_dict = update_kg_type_dict(
                kg_update_dict, "kg_entities", kg_update_request.entities
            )
        if kg_update_request.terms is not None:
            kg_update_dict = update_kg_type_dict(
                kg_update_dict, "kg_terms", kg_update_request.terms
            )

        if not kg_update_dict["fields"]:
            # nothing to write for this chunk; skip rather than issue an empty PUT
            logger.error("Update request received but nothing to update")
            continue

        doc_chunk_id = get_uuid_from_chunk_info(
            document_id=kg_update_request.document_id,
            chunk_id=kg_update_request.chunk_id,
            tenant_id=tenant_id,
            large_chunk_id=None,
        )

        processed_updates_requests.append(
            KGVespaChunkUpdateRequest(
                document_id=kg_update_request.document_id,
                chunk_id=kg_update_request.chunk_id,
                url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
                update_request=kg_update_dict,
            )
        )

    with self.httpx_client_context as httpx_client:
        self._apply_kg_chunk_updates_batched(
            processed_updates_requests, httpx_client
        )

    logger.debug(
        "Updated %d vespa documents in %.2f seconds",
        len(processed_updates_requests),
        time.monotonic() - update_start,
    )
@retry(
tries=3,
delay=1,

View File

@@ -0,0 +1,78 @@
from retry import retry
from onyx.db.document import get_document_kg_entities_and_relationships
from onyx.db.engine import get_session_with_current_tenant
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import generalize_relationships
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@retry(tries=3, delay=1, backoff=2)
def update_kg_chunks_vespa_info(
    kg_update_requests: list[KGUChunkUpdateRequest],
    index_name: str,
    tenant_id: str,
) -> None:
    """Push per-chunk KG field updates to Vespa for the given index.

    Builds a VespaIndex for `index_name` and applies the partial updates;
    the whole operation is retried up to 3 times with backoff on failure.
    """
    # Use the existing visit API infrastructure
    vespa_index = VespaIndex(
        index_name=index_name,
        secondary_index_name=None,
        large_chunks_enabled=False,
        secondary_large_chunks_enabled=False,
        multitenant=MULTI_TENANT,
        httpx_client=None,
    )

    vespa_index.kg_chunk_updates(
        kg_update_requests=kg_update_requests, tenant_id=tenant_id
    )
def get_kg_vespa_info_update_requests_for_document(
    document_id: str, index_name: str, tenant_id: str
) -> list[KGUChunkUpdateRequest]:
    """Get the kg_info update requests for a document.

    Collects the document's KG entities/relationships (plus their generalized
    forms) and builds one chunk-level update request per chunk of the document.
    """
    # get all entities and relationships tied to the document
    with get_session_with_current_tenant() as db_session:
        entities, relationships = get_document_kg_entities_and_relationships(
            db_session, document_id
        )

    # create the kg vespa info: the concrete id_names plus their generalized forms
    entity_id_names = [entity.id_name for entity in entities]
    relationship_id_names = [relationship.id_name for relationship in relationships]
    kg_entities = generalize_entities(entity_id_names) | set(entity_id_names)
    kg_relationships = generalize_relationships(relationship_id_names) | set(
        relationship_id_names
    )

    # get chunks in the document
    chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=IndexFilters(access_control_list=None, tenant_id=tenant_id),
        field_names=["chunk_id"],
        get_large_chunks=False,
    )

    # get vespa update requests
    # NOTE(review): `relationships` falls back to None when the set is empty
    # while `entities` is always passed (possibly as an empty set) — confirm
    # this asymmetry is intended.
    return [
        KGUChunkUpdateRequest(
            document_id=document_id,
            chunk_id=chunk["fields"]["chunk_id"],
            core_entity="unused",
            entities=kg_entities,
            relationships=kg_relationships or None,
        )
        for chunk in chunks
    ]

View File

@@ -54,6 +54,47 @@ def build_vespa_filters(
return result
def _build_kg_filter(
kg_entities: list[str] | None,
kg_relationships: list[str] | None,
kg_terms: list[str] | None,
) -> str:
if not kg_entities and not kg_relationships and not kg_terms:
return ""
filter_parts = []
# Process each filter type using the same pattern
for filter_type, values in [
("kg_entities", kg_entities),
("kg_relationships", kg_relationships),
("kg_terms", kg_terms),
]:
if values:
filter_parts.append(
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
)
return f"({' and '.join(filter_parts)}) and "
def _build_kg_source_filters(
kg_sources: list[str] | None,
) -> str:
if not kg_sources:
return ""
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
return f"({' or '.join(source_phrases)}) and "
def _build_kg_chunk_id_zero_only_filter(
kg_chunk_id_zero_only: bool,
) -> str:
if not kg_chunk_id_zero_only:
return ""
return "(chunk_id = 0 ) and "
def _build_time_filter(
cutoff: datetime | None,
untimed_doc_cutoff: timedelta = timedelta(days=92),
@@ -71,7 +112,9 @@ def build_vespa_filters(
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
# If running in multi-tenant mode
if filters.tenant_id and MULTI_TENANT:
if MULTI_TENANT:
if not filters.tenant_id:
raise ValueError("Tenant ID is required in multi-tenant mode")
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# ACL filters
@@ -106,6 +149,19 @@ def build_vespa_filters(
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
# Knowledge Graph Filters
filter_str += _build_kg_filter(
kg_entities=filters.kg_entities,
kg_relationships=filters.kg_relationships,
kg_terms=filters.kg_terms,
)
filter_str += _build_kg_source_filters(filters.kg_sources)
filter_str += _build_kg_chunk_id_zero_only_filter(
filters.kg_chunk_id_zero_only or False
)
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]

View File

@@ -0,0 +1,279 @@
from typing import cast
from rapidfuzz.fuzz import ratio
from sqlalchemy import text
from onyx.configs.kg_configs import KG_CLUSTERING_RETRIEVE_THRESHOLD
from onyx.configs.kg_configs import KG_CLUSTERING_THRESHOLD
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import KGEntity
from onyx.db.entities import KGEntityExtractionStaging
from onyx.db.entities import merge_entities
from onyx.db.entities import transfer_entity
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import Document
from onyx.db.models import KGEntityType
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.db.relationships import get_parent_child_relationships_and_types
from onyx.db.relationships import transfer_relationship
from onyx.db.relationships import transfer_relationship_type
from onyx.document_index.vespa.kg_interactions import (
get_kg_vespa_info_update_requests_for_document,
)
from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
from onyx.kg.models import KGGroundingType
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def _cluster_one_grounded_entity(
    entity: KGEntityExtractionStaging,
) -> tuple[KGEntity, bool]:
    """
    Cluster a single grounded entity.

    Fuzzy-matches the staged entity against existing normalized KGEntity rows
    of the same type (pg_trgm similarity, then rapidfuzz ratio). On a close
    enough match the staged entity is merged into it; otherwise it is
    transferred as a new normalized entity.

    Returns:
        (resulting normalized entity, whether the owning document's Vespa
        chunks need a KG-info refresh)
    """
    with get_session_with_current_tenant() as db_session:
        # get entity name and filtering conditions
        if entity.document_id is not None:
            # document-grounded: use the document's semantic id as the name,
            # and only merge into entities not already tied to a document
            entity_name = cast(
                str,
                db_session.query(Document.semantic_id)
                .filter(Document.id == entity.document_id)
                .scalar(),
            ).lower()
            filtering = [KGEntity.document_id.is_(None)]
        else:
            entity_name = entity.name.lower()
            filtering = []

        # skip those with numbers so we don't cluster version1 and version2, etc.
        similar_entities: list[KGEntity] = []
        if not any(char.isdigit() for char in entity_name):
            # find similar entities, uses GIN index, very efficient
            db_session.execute(
                text(
                    "SET pg_trgm.similarity_threshold = "
                    + str(KG_CLUSTERING_RETRIEVE_THRESHOLD)
                )
            )
            similar_entities = (
                db_session.query(KGEntity)
                .filter(
                    # find entities of the same type with a similar name
                    *filtering,
                    KGEntity.entity_type_id_name == entity.entity_type_id_name,
                    KGEntity.name.op("%")(entity_name),
                )
                .all()
            )

    # find best match among the retrieved candidates (rapidfuzz ratio is 0-100)
    best_score = -1.0
    best_entity = None
    for similar in similar_entities:
        # skip those with numbers so we don't cluster version1 and version2, etc.
        if any(char.isdigit() for char in similar.name):
            continue
        score = ratio(similar.name, entity_name)
        if score >= KG_CLUSTERING_THRESHOLD * 100 and score > best_score:
            best_score = score
            best_entity = similar

    # if there is a match, update the entity, otherwise create a new one
    with get_session_with_current_tenant() as db_session:
        if best_entity:
            logger.debug(f"Merged {entity.name} with {best_entity.name}")
            # Vespa only needs a refresh when the merge newly attaches a
            # document to a previously document-less entity.
            update_vespa = (
                best_entity.document_id is None and entity.document_id is not None
            )
            transferred_entity = merge_entities(
                db_session=db_session, parent=best_entity, child=entity
            )
        else:
            update_vespa = entity.document_id is not None
            transferred_entity = transfer_entity(db_session=db_session, entity=entity)
        db_session.commit()

    return transferred_entity, update_vespa
def _transfer_batch_relationship(
    relationships: list[KGRelationshipExtractionStaging],
    entity_translations: dict[str, str],
) -> set[str]:
    """Transfer a batch of staged relationships into the normalized tables.

    Args:
        relationships: staged relationships to transfer
        entity_translations: staging entity id_name -> normalized id_name

    Returns:
        Document ids owned by the normalized entities these relationships
        touch (used by the caller to refresh Vespa KG fields).
    """
    updated_documents: set[str] = set()
    with get_session_with_current_tenant() as db_session:
        entity_id_names: set[str] = set()
        for relationship in relationships:
            transferred_relationship = transfer_relationship(
                db_session=db_session,
                relationship=relationship,
                entity_translations=entity_translations,
            )
            entity_id_names.add(transferred_relationship.source_node)
            entity_id_names.add(transferred_relationship.target_node)

        # collect the owning documents of every entity touched above
        updated_documents.update(
            (
                res[0]
                for res in db_session.query(KGEntity.document_id)
                .filter(KGEntity.id_name.in_(entity_id_names))
                .all()
                if res[0] is not None
            )
        )
        db_session.commit()
    return updated_documents
def kg_clustering(
    tenant_id: str, index_name: str, processing_chunk_batch_size: int = 16
) -> None:
    """
    Here we will cluster the extractions based on their cluster frameworks.
    Initially, this will only focus on grounded entities with pre-determined
    relationships, so 'clustering' is actually not yet required.
    However, we may need to reconcile entities coming from different sources.

    The primary purpose of this function is to populate the actual KG tables
    from the temp_extraction tables.

    This will change with deep extraction, where grounded-sourceless entities
    can be extracted and then need to be clustered.

    Args:
        tenant_id: tenant whose staging data is clustered/transferred
        index_name: Vespa index to refresh with updated KG info
        processing_chunk_batch_size: batch size for parallel transfer/update
    """
    logger.info(f"Starting kg clustering for tenant {tenant_id}")

    with get_session_with_current_tenant() as db_session:
        kg_config_settings = get_kg_config_settings(db_session)

    # Retrieve staging data
    with get_session_with_current_tenant() as db_session:
        untransferred_relationship_types = (
            db_session.query(KGRelationshipTypeExtractionStaging)
            .filter(KGRelationshipTypeExtractionStaging.transferred.is_(False))
            .all()
        )
        untransferred_relationships = (
            db_session.query(KGRelationshipExtractionStaging)
            .filter(KGRelationshipExtractionStaging.transferred.is_(False))
            .all()
        )
        grounded_entities = (
            db_session.query(KGEntityExtractionStaging)
            .join(
                KGEntityType,
                KGEntityExtractionStaging.entity_type_id_name == KGEntityType.id_name,
            )
            .filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
            .all()
        )

    # Cluster and transfer grounded entities
    untransferred_grounded_entities = [
        entity for entity in grounded_entities if entity.transferred_id_name is None
    ]
    # translations for already-transferred entities are still needed to
    # resolve relationship endpoints below
    entity_translations: dict[str, str] = {
        entity.id_name: entity.transferred_id_name
        for entity in grounded_entities
        if entity.transferred_id_name is not None
    }
    vespa_update_documents: set[str] = set()

    for entity in untransferred_grounded_entities:
        added_entity, update_vespa = _cluster_one_grounded_entity(entity)
        entity_translations[entity.id_name] = added_entity.id_name
        if update_vespa and added_entity.document_id is not None:
            vespa_update_documents.add(added_entity.document_id)
    logger.info(f"Transferred {len(untransferred_grounded_entities)} entities")

    # Add parent-child relationships and relationship types
    with get_session_with_current_tenant() as db_session:
        parent_child_relationships, parent_child_relationship_types = (
            get_parent_child_relationships_and_types(
                db_session, depth=kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH
            )
        )
        untransferred_relationship_types.extend(parent_child_relationship_types)
        untransferred_relationships.extend(parent_child_relationships)
        db_session.commit()

    # Transfer the relationship types (one short-lived session per type)
    for relationship_type in untransferred_relationship_types:
        with get_session_with_current_tenant() as db_session:
            transfer_relationship_type(db_session, relationship_type=relationship_type)
            db_session.commit()
    logger.info(
        f"Transferred {len(untransferred_relationship_types)} relationship types"
    )

    # Transfer relationships in parallel
    updated_documents_batch: list[set[str]] = run_functions_tuples_in_parallel(
        [
            (
                _transfer_batch_relationship,
                (
                    untransferred_relationships[
                        batch_i : batch_i + processing_chunk_batch_size
                    ],
                    entity_translations,
                ),
            )
            for batch_i in range(
                0, len(untransferred_relationships), processing_chunk_batch_size
            )
        ]
    )
    for updated_documents in updated_documents_batch:
        vespa_update_documents.update(updated_documents)
    logger.info(f"Transferred {len(untransferred_relationships)} relationships")

    # Update vespa for documents that had their kg info updated in parallel
    for i in range(0, len(vespa_update_documents), processing_chunk_batch_size):
        batch_update_requests = run_functions_tuples_in_parallel(
            [
                (
                    get_kg_vespa_info_update_requests_for_document,
                    (document_id, index_name, tenant_id),
                )
                for document_id in list(vespa_update_documents)[
                    i : i + processing_chunk_batch_size
                ]
            ]
        )
        for update_requests in batch_update_requests:
            update_kg_chunks_vespa_info(update_requests, index_name, tenant_id)

    # Delete the transferred objects from the staging tables
    # (best-effort: each cleanup is independent and failures are only logged)
    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGRelationshipExtractionStaging).filter(
                KGRelationshipExtractionStaging.transferred.is_(True)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationships: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGRelationshipTypeExtractionStaging).filter(
                KGRelationshipTypeExtractionStaging.transferred.is_(True)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationship types: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGEntityExtractionStaging).filter(
                KGEntityExtractionStaging.transferred_id_name.is_not(None)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting entities: {e}")

View File

@@ -0,0 +1,403 @@
import re
from collections import defaultdict
from typing import cast
import numpy as np
from nltk import ngrams # type: ignore
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
from sqlalchemy import desc
from sqlalchemy import Float
from sqlalchemy import func
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.dialects.postgresql import ARRAY
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_THRESHOLD
from onyx.configs.kg_configs import KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.relationships import get_relationships_for_entity_type_pairs
from onyx.kg.models import NormalizedEntities
from onyx.kg.models import NormalizedRelationships
from onyx.kg.models import NormalizedTerms
from onyx.kg.utils.embeddings import encode_string_batch
from onyx.kg.utils.formatting_utils import format_entity_id_for_models
from onyx.kg.utils.formatting_utils import get_entity_type
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
alphanum_regex = re.compile(r"[^a-z0-9]+")
rem_email_regex = re.compile(r"(?<=\S)@([a-z0-9-]+)\.([a-z]{2,6})$")
def _clean_name(entity_name: str) -> str:
"""
Clean an entity string by removing non-alphanumeric characters and email addresses.
If the name after cleaning is empty, return the original name in lowercase.
"""
cleaned_entity = entity_name.casefold()
return (
alphanum_regex.sub("", rem_email_regex.sub("", cleaned_entity))
or cleaned_entity
)
def _normalize_one_entity(
    entity: str, allowed_docs_temp_view_name: str | None = None
) -> str | None:
    """
    Matches a single entity to the best matching entity of the same type.

    Two stages: (1) a trigram-overlap SQL query (pg_trgm) retrieves candidate
    KGEntity rows of the same type, restricted to document-less entities or
    those whose document appears in the allowed-docs temp view; (2) the
    candidates are re-ranked in Python with weighted ngram overlap plus
    Damerau-Levenshtein similarity.

    Args:
        entity: entity id in "TYPE:name" form ("TYPE:*" is returned as-is)
        allowed_docs_temp_view_name: required name of a temp view with an
            `allowed_doc_id` column limiting which documents may be matched

    Returns:
        id_name of the best-scoring normalized entity, or None if no
        candidate clears the rerank threshold

    Raises:
        ValueError: if allowed_docs_temp_view_name is None
    """
    entity_type, entity_name = split_entity_id(entity)
    if entity_name == "*":
        return entity
    cleaned_entity = _clean_name(entity_name)

    # step 1: find entities containing the entity_name or something similar
    with get_session_with_current_tenant() as db_session:
        # get allowed documents
        metadata = MetaData()
        if allowed_docs_temp_view_name is None:
            raise ValueError("allowed_docs_temp_view_name is not available")
        allowed_docs_temp_view = Table(
            allowed_docs_temp_view_name,
            metadata,
            autoload_with=db_session.get_bind(),
        )

        # generate trigrams of the queried entity Q
        query_trigrams = db_session.query(
            func.show_trgm(cleaned_entity).cast(ARRAY(String(3))).label("trigrams")
        ).cte("query")

        candidates = cast(
            list[tuple[str, str, float]],
            db_session.query(
                KGEntity.id_name,
                KGEntity.name,
                (
                    # for each entity E, compute score = | Q ∩ E | / min(|Q|, |E|)
                    func.cardinality(
                        func.array(
                            select(func.unnest(KGEntity.name_trigrams))
                            .correlate(KGEntity)
                            .intersect(
                                select(
                                    func.unnest(query_trigrams.c.trigrams)
                                ).correlate(query_trigrams)
                            )
                            .scalar_subquery()
                        )
                    ).cast(Float)
                    / func.least(
                        func.cardinality(query_trigrams.c.trigrams),
                        func.cardinality(KGEntity.name_trigrams),
                    )
                ).label("score"),
            )
            .select_from(KGEntity, query_trigrams)
            .outerjoin(
                allowed_docs_temp_view,
                KGEntity.document_id == allowed_docs_temp_view.c.allowed_doc_id,
            )
            .filter(
                KGEntity.entity_type_id_name == entity_type,
                KGEntity.name_trigrams.overlap(query_trigrams.c.trigrams),
                # Add filter for allowed docs - either document_id is NULL or it's in allowed_docs
                (
                    KGEntity.document_id.is_(None)
                    | allowed_docs_temp_view.c.allowed_doc_id.isnot(None)
                ),
            )
            .order_by(desc("score"))
            .limit(KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT)
            .all(),
        )
    if not candidates:
        return None

    # step 2: do a weighted ngram analysis and damerau levenshtein distance to rerank
    n1, n2, n3 = (
        set(ngrams(cleaned_entity, 1)),
        set(ngrams(cleaned_entity, 2)),
        set(ngrams(cleaned_entity, 3)),
    )
    for i, (candidate_id_name, candidate_name, _) in enumerate(candidates):
        cleaned_candidate = _clean_name(candidate_name)
        h_n1, h_n2, h_n3 = (
            set(ngrams(cleaned_candidate, 1)),
            set(ngrams(cleaned_candidate, 2)),
            set(ngrams(cleaned_candidate, 3)),
        )

        # compute ngram overlap, renormalize scores if the names are too short for larger ngrams
        # grams_used: 0 = unigrams only, 1 = up to bigrams, 2 = up to trigrams
        grams_used = min(2, len(cleaned_entity) - 1, len(cleaned_candidate) - 1)
        W_n1, W_n2, W_n3 = KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
        ngram_score = (
            # compute | Q ∩ E | / max(1, min(|Q|, |E|)) for each ngram size
            W_n1 * len(n1 & h_n1) / max(1, min(len(n1), len(h_n1)))
            + W_n2 * len(n2 & h_n2) / max(1, min(len(n2), len(h_n2)))
            + W_n3 * len(n3 & h_n3) / max(1, min(len(n3), len(h_n3)))
            # divide by the cumulative weight of the ngram sizes actually usable;
            # NOTE(review): the grams_used == 2 case divides by 1.0, which
            # assumes the three weights sum to 1 — confirm the config enforces it
        ) / (W_n1, W_n1 + W_n2, 1.0)[grams_used]

        # compute damerau levenshtein distance to fuzzy match against typos
        W_leven = KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
        leven_score = normalized_similarity(cleaned_entity, cleaned_candidate)

        # combine scores
        score = (1.0 - W_leven) * ngram_score + W_leven * leven_score
        candidates[i] = (candidate_id_name, candidate_name, score)

    # keep only candidates above the rerank threshold, best first
    candidates = list(
        sorted(
            filter(lambda x: x[2] > KG_NORMALIZATION_RERANK_THRESHOLD, candidates),
            key=lambda x: x[2],
            reverse=True,
        )
    )
    if not candidates:
        return None
    return candidates[0][0]
def _get_existing_normalized_relationships(
    raw_relationships: list[str],
) -> dict[str, dict[str, list[str]]]:
    """
    Get existing normalized relationships from the database.

    Returns:
        Mapping of source entity type -> target entity type -> list of
        relationship-type id_names already known for that type pair.
    """
    relationship_type_map: dict[str, dict[str, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )

    # unique (source entity type, target entity type) pairs in the raw input
    relationship_pairs = list(
        {
            (
                get_entity_type(split_relationship_id(relationship)[0]),
                get_entity_type(split_relationship_id(relationship)[2]),
            )
            for relationship in raw_relationships
        }
    )
    with get_session_with_current_tenant() as db_session:
        relationships = get_relationships_for_entity_type_pairs(
            db_session, relationship_pairs
        )

    for relationship in relationships:
        relationship_type_map[relationship.source_entity_type_id_name][
            relationship.target_entity_type_id_name
        ].append(relationship.id_name)

    return relationship_type_map
def normalize_entities(
    raw_entities_no_attributes: list[str],
    allowed_docs_temp_view_name: str | None = None,
) -> NormalizedEntities:
    """
    Match each entity against a list of normalized entities using fuzzy matching.
    Returns the best matching normalized entity for each input entity.

    Args:
        raw_entities_no_attributes: list of entity strings to normalize, w/o attributes
        allowed_docs_temp_view_name: temp view name restricting which documents'
            entities may be matched (passed through to _normalize_one_entity)

    Returns:
        NormalizedEntities with the matched entity ids and a map from each
        model-formatted raw entity id to its normalized id (the raw id itself
        when no match was found)
    """
    normalized_results: list[str] = []
    normalized_map: dict[str, str] = {}

    # normalize each entity in parallel (each lookup opens its own DB session)
    mapping: list[str | None] = run_functions_tuples_in_parallel(
        [
            (_normalize_one_entity, (entity, allowed_docs_temp_view_name))
            for entity in raw_entities_no_attributes
        ]
    )
    for entity, normalized_entity in zip(raw_entities_no_attributes, mapping):
        if normalized_entity is not None:
            normalized_results.append(normalized_entity)
            normalized_map[format_entity_id_for_models(entity)] = normalized_entity
        else:
            # no DB match: map the raw entity to itself
            normalized_map[format_entity_id_for_models(entity)] = entity

    return NormalizedEntities(
        entities=normalized_results, entity_normalization_map=normalized_map
    )
def normalize_entities_w_attributes_from_map(
    raw_entities_w_attributes: list[str],
    entity_normalization_map: dict[str, str],
) -> list[str]:
    """
    Normalize entities with attributes ("<entity>--<attributes>") using a
    previously computed entity normalization map.

    Entries whose entity part has no mapping are dropped with a warning.

    Args:
        raw_entities_w_attributes: entries of the form "<entity>--<attributes>"
        entity_normalization_map: raw entity (model-formatted) -> normalized entity

    Returns:
        List of "<normalized entity>--<attributes>" strings.

    Raises:
        ValueError: if an entry does not contain exactly one "--" separator.
    """
    normalized_entities_w_attributes: list[str] = []
    for raw_entity_w_attributes in raw_entities_w_attributes:
        # Split once and reuse; an explicit raise (instead of assert) keeps the
        # validation active under `python -O`.
        parts = raw_entity_w_attributes.split("--")
        if len(parts) != 2:
            raise ValueError(
                f"Invalid entity with attributes: {raw_entity_w_attributes}"
            )
        raw_entity, attributes = parts

        normalized_entity = entity_normalization_map.get(
            format_entity_id_for_models(raw_entity)
        )
        if normalized_entity is None:
            logger.warning(f"No normalized entity found for {raw_entity}")
            continue
        normalized_entities_w_attributes.append(
            f"{normalized_entity}--{attributes.strip()}"
        )
    return normalized_entities_w_attributes
def normalize_relationships(
    raw_relationships: list[str], entity_normalization_map: dict[str, str]
) -> NormalizedRelationships:
    """
    Normalize relationships using entity mappings and relationship string matching.

    For each "source__relation__target" string: map both entities through the
    normalization map, gather the known relationship types between the two
    normalized entity types, and keep the candidate whose embedding is closest
    (by dot product) to the raw relation string.

    Args:
        raw_relationships: relationships in format "source__relation__target"
        entity_normalization_map: mapping of raw entities to normalized ones

    Returns:
        NormalizedRelationships containing normalized relationships and mapping
    """
    existing_rel_types = _get_existing_normalized_relationships(raw_relationships)

    normalized_relationships: list[str] = []
    relationship_map: dict[str, str] = {}

    for raw_rel in raw_relationships:
        # 1. Split and normalize the entities on both ends.
        try:
            source, rel_string, target = split_relationship_id(raw_rel)
        except ValueError:
            raise ValueError(f"Invalid relationship format: {raw_rel}")

        norm_source = entity_normalization_map.get(source)
        norm_target = entity_normalization_map.get(target)
        if norm_source is None or norm_target is None:
            logger.warning(f"No normalized entities found for {raw_rel}")
            continue

        # 2. Collect candidate relationship strings between the two entity types.
        source_type = get_entity_type(norm_source)
        target_type = get_entity_type(norm_target)
        candidates = [
            split_relationship_id(existing_rel)[1]
            for existing_rel in existing_rel_types.get(source_type, {}).get(
                target_type, []
            )
        ]
        if not candidates:
            logger.warning(f"No candidate relationships found for {raw_rel}")
            continue

        # 3. Embed the raw relation together with the candidates and pick the
        # candidate with the highest dot-product similarity.
        vectors = encode_string_batch([rel_string] + candidates)
        similarities = np.dot(vectors[1:], vectors[0])
        best_candidate = candidates[int(np.argmax(similarities))]

        normalized_rel = make_relationship_id(norm_source, best_candidate, norm_target)
        normalized_relationships.append(normalized_rel)
        relationship_map[raw_rel] = normalized_rel

    return NormalizedRelationships(
        relationships=normalized_relationships,
        relationship_normalization_map=relationship_map,
    )
def normalize_terms(raw_terms: list[str]) -> NormalizedTerms:
    """
    Normalize terms.

    Currently a pass-through: no canonical term list exists yet, so every term
    maps to itself. Once a canonical list is available, this should match each
    raw term against it via embedding similarity (analogous to
    normalize_relationships), keep matches above a threshold, and map unmatched
    terms to None.

    Args:
        raw_terms: list of terms to normalize

    Returns:
        NormalizedTerms containing the (unchanged) terms and an identity mapping.
    """
    # TODO: implement embedding-based matching against a canonical term list.
    return NormalizedTerms(
        terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
    )

View File

@@ -0,0 +1,53 @@
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entity_type import populate_default_employee_account_information
from onyx.db.entity_type import (
populate_default_primary_grounded_entity_type_information,
)
from onyx.db.kg_config import get_kg_enablement
from onyx.db.kg_config import KGConfigSettings
from onyx.utils.logger import setup_logger
logger = setup_logger()
def populate_default_grounded_entity_types() -> None:
    """
    Seed the default primary grounded entity types.

    Raises:
        ValueError: if the knowledge graph feature is not enabled.
    """
    with get_session_with_current_tenant() as db_session:
        if not get_kg_enablement(db_session):
            message = (
                "KG approach is not enabled, the entity types cannot be populated."
            )
            logger.error(message)
            raise ValueError(message)
        populate_default_primary_grounded_entity_type_information(db_session)
        db_session.commit()
def populate_default_account_employee_definitions() -> None:
    """
    Seed the default ACCOUNT/EMPLOYEE entity type information.

    Raises:
        ValueError: if the knowledge graph feature is not enabled.
    """
    with get_session_with_current_tenant() as db_session:
        if not get_kg_enablement(db_session):
            message = (
                "KG approach is not enabled, the entity types cannot be populated."
            )
            logger.error(message)
            raise ValueError(message)
        populate_default_employee_account_information(db_session)
        db_session.commit()
def validate_kg_settings(kg_config_settings: KGConfigSettings) -> None:
    """
    Ensure the KG configuration is complete enough to run.

    Raises:
        ValueError: if KG is disabled, or the vendor / vendor domains are unset.
    """
    required = (
        (kg_config_settings.KG_ENABLED, "KG is not enabled"),
        (kg_config_settings.KG_VENDOR, "KG_VENDOR is not set"),
        (kg_config_settings.KG_VENDOR_DOMAINS, "KG_VENDOR_DOMAINS is not set"),
    )
    for value, error_message in required:
        if not value:
            raise ValueError(error_message)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,100 @@
from pydantic import BaseModel
from onyx.kg.models import KGDefaultEntityDefinition
from onyx.kg.models import KGGroundingType
class KGDefaultPrimaryGroundedEntityDefinitions(BaseModel):
    """
    Default grounded entity types, one per supported document source.
    The "---vendor_name---" placeholder in descriptions is presumably replaced
    with the configured vendor name when these are persisted — TODO confirm.
    """

    LINEAR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A formal ticket about a product issue or improvement request.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="linear",
    )
    GITHUB_PR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="Our (---vendor_name---) Engineering PRs describing what was actually implemented.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="github",
    )
    FIREFLIES: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A phone call transcript between us (---vendor_name---) \
and another account or individuals, or an internal meeting.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="fireflies",
    )
    GONG: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A phone call transcript between us (---vendor_name---) \
and another account or individuals, or an internal meeting.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="gong",
    )
    GOOGLE_DRIVE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A Google Drive document.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="google_drive",
    )
    GMAIL: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="An email.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="gmail",
    )
    JIRA: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A formal JIRA ticket about a product issue or improvement request.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="jira",
    )
    # Salesforce-backed types filter on the Salesforce object type.
    ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A company that was, is, or potentially could be a customer of the vendor \
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
        attributes={
            "metadata_attributes": {},
            "entity_filter_attributes": {"object_type": "Account"},
            "classification_attributes": {},
        },
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="salesforce",
    )
    OPPORTUNITY: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A sales opportunity.",
        attributes={
            "metadata_attributes": {},
            "entity_filter_attributes": {"object_type": "Opportunity"},
            "classification_attributes": {},
        },
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="salesforce",
    )
class KGDefaultAccountEmployeeDefinitions(BaseModel):
    """
    Default ungrounded-source entity types (VENDOR/ACCOUNT/EMPLOYEE).
    All start inactive and have no grounded source connector.
    """

    VENDOR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="The Vendor ---vendor_name---, 'us'",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )
    ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A company that was, is, or potentially could be a customer of the vendor \
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )
    EMPLOYEE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A person who speaks on \
behalf of 'our' company (the VENDOR ---vendor_name---), NOT of another account. Therefore, employees of other companies \
are NOT included here. If in doubt, do NOT extract.",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )

223
backend/onyx/kg/models.py Normal file
View File

@@ -0,0 +1,223 @@
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
from onyx.configs.kg_configs import KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH
class KGConfigSettings(BaseModel):
    """Tenant-level knowledge-graph configuration."""

    KG_EXPOSED: bool = False
    KG_ENABLED: bool = False
    # Name of "our" company; interpolated into extraction/classification prompts.
    KG_VENDOR: str | None = None
    # Email domains treated as vendor-side — TODO confirm against kg_email_processing
    KG_VENDOR_DOMAINS: list[str] | None = None
    # Domains whose people are skipped during person extraction.
    KG_IGNORE_EMAIL_DOMAINS: list[str] | None = None
    KG_EXTRACTION_IN_PROGRESS: bool = False
    KG_CLUSTERING_IN_PROGRESS: bool = False
    # Epoch default effectively means "cover all documents".
    KG_COVERAGE_START: datetime = datetime(1970, 1, 1)
    KG_MAX_COVERAGE_DAYS: int = 10000
    KG_MAX_PARENT_RECURSION_DEPTH: int = KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH
class KGConfigVars(str, Enum):
    """String keys for the KG config variables; mirrors KGConfigSettings fields."""

    KG_EXPOSED = "KG_EXPOSED"
    KG_ENABLED = "KG_ENABLED"
    KG_VENDOR = "KG_VENDOR"
    KG_VENDOR_DOMAINS = "KG_VENDOR_DOMAINS"
    KG_IGNORE_EMAIL_DOMAINS = "KG_IGNORE_EMAIL_DOMAINS"
    KG_EXTRACTION_IN_PROGRESS = "KG_EXTRACTION_IN_PROGRESS"
    KG_CLUSTERING_IN_PROGRESS = "KG_CLUSTERING_IN_PROGRESS"
    KG_COVERAGE_START = "KG_COVERAGE_START"
    KG_MAX_COVERAGE_DAYS = "KG_MAX_COVERAGE_DAYS"
    KG_MAX_PARENT_RECURSION_DEPTH = "KG_MAX_PARENT_RECURSION_DEPTH"
class KGChunkFormat(BaseModel):
    """A document chunk in the shape the KG extraction pipeline consumes."""

    connector_id: int | None = None
    document_id: str
    chunk_id: int
    title: str
    content: str
    primary_owners: list[str]
    secondary_owners: list[str]
    source_type: str
    metadata: dict[str, str | list[str]] | None = None
    # Extraction results with occurrence counts. (Pydantic deep-copies mutable
    # defaults per instance, so the shared-{} pitfall does not apply here.)
    entities: dict[str, int] = {}
    relationships: dict[str, int] = {}
    terms: dict[str, int] = {}
    deep_extraction: bool = False


class KGChunkExtraction(BaseModel):
    """Raw extraction output for a single chunk."""

    connector_id: int
    document_id: str
    chunk_id: int
    core_entity: str
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    attributes: dict[str, str | list[str]]
class KGChunkId(BaseModel):
    """Identifies a single chunk (document id + chunk index, optionally connector)."""

    connector_id: int | None = None
    document_id: str
    chunk_id: int


class KGRelationshipExtraction(BaseModel):
    """One extracted relationship string plus the document it came from."""

    relationship_str: str
    source_document_id: str


class KGAggregatedExtractions(BaseModel):
    """Extraction results aggregated across chunks."""

    # grounded entity id-name -> document id backing it — TODO confirm key/value
    grounded_entities_document_ids: dict[str, str]
    # entity id-name -> occurrence count
    entities: dict[str, int]
    # outer key -> (relationship -> count) — presumably keyed by document id; verify
    relationships: dict[str, dict[str, int]]
    # term -> occurrence count
    terms: dict[str, int]
    # entity id-name -> attribute name/value mapping — TODO confirm key semantics
    attributes: dict[str, dict[str, str | list[str]]]


class KGBatchExtractionStats(BaseModel):
    """Outcome of one extraction batch: chunk successes/failures plus results."""

    connector_id: int | None = None
    succeeded: list[KGChunkId]
    failed: list[KGChunkId]
    aggregated_kg_extractions: KGAggregatedExtractions


class ConnectorExtractionStats(BaseModel):
    """Per-connector extraction counters."""

    connector_id: int
    num_succeeded: int
    num_failed: int
    num_processed: int


class KGPerson(BaseModel):
    """A person parsed from an email address."""

    name: str
    company: str
    # True if the person belongs to the vendor (see kg_process_person usage).
    employee: bool
class NormalizedEntities(BaseModel):
    """Result of entity normalization: matched entities + raw -> normalized map."""

    entities: list[str]
    entity_normalization_map: dict[str, str]


class NormalizedRelationships(BaseModel):
    """Result of relationship normalization: normalized rels + raw -> normalized map."""

    relationships: list[str]
    relationship_normalization_map: dict[str, str]


class NormalizedTerms(BaseModel):
    """Result of term normalization; terms without a match map to None."""

    terms: list[str]
    term_normalization_map: dict[str, str | None]
class KGClassificationContent(BaseModel):
    """Document content prepared for KG classification."""

    document_id: str
    # Text snippet used as classification input.
    classification_content: str
    source_type: str
    source_metadata: dict[str, Any] | None = None
    entity_type: str | None = None
    metadata: dict[str, Any] | None = None


class KGEnrichedClassificationContent(KGClassificationContent):
    """Classification content enriched with entity-type-level instructions."""

    classification_enabled: bool
    classification_instructions: dict[str, Any]
    deep_extraction: bool


class KGClassificationDecisions(BaseModel):
    """Outcome of classifying a single document."""

    document_id: str
    classification_decision: bool
    # Assigned class, if any.
    classification_class: str | None
    source_metadata: dict[str, Any] | None = None


class KGClassificationInstructions(BaseModel):
    """Per-entity-type instructions for document classification."""

    classification_enabled: bool
    classification_options: str
    classification_class_definitions: dict[str, dict[str, str | bool]]


class KGExtractionInstructions(BaseModel):
    """Per-entity-type extraction switches."""

    deep_extraction: bool
    active: bool


class KGEntityTypeInstructions(BaseModel):
    """Bundle of classification/extraction/filter instructions for an entity type."""

    classification_instructions: KGClassificationInstructions
    extraction_instructions: KGExtractionInstructions
    filter_instructions: dict[str, Any] | None = None


class KGEnhancedDocumentMetadata(BaseModel):
    """Document metadata augmented with KG processing decisions."""

    entity_type: str | None
    document_attributes: dict[str, Any] | None
    deep_extraction: bool
    classification_enabled: bool
    classification_instructions: KGClassificationInstructions | None
    # True when the document should be excluded from KG processing.
    skip: bool
class ContextPreparation(BaseModel):
    """
    Context preparation format for the LLM KG extraction.
    """

    llm_context: str
    core_entity: str
    implied_entities: list[str]
    implied_relationships: list[str]
    implied_terms: list[str]


class KGDocumentClassificationPrompt(BaseModel):
    """
    Document classification prompt format for the LLM KG extraction.
    """

    llm_prompt: str | None


class KGConnectorData(BaseModel):
    """Minimal connector info used by KG processing."""

    id: int
    source: str
    kg_coverage_days: int | None


class KGStage(str, Enum):
    """Lifecycle stage of a document within the KG pipeline."""

    EXTRACTED = "extracted"
    NORMALIZED = "normalized"
    FAILED = "failed"
    SKIPPED = "skipped"
    NOT_STARTED = "not_started"
    EXTRACTING = "extracting"
    DO_NOT_EXTRACT = "do_not_extract"
class KGDocumentEntitiesRelationshipsAttributes(BaseModel):
    """Entities/relationships/attributes derived from a document's owners and metadata."""

    # Id-name of the KG entity representing the document itself.
    kg_core_document_id_name: str
    implied_entities: set[str]
    implied_relationships: set[str]
    # relationship type -> non-email owner values kept as attributes instead.
    converted_relationships_to_attributes: dict[str, list[str]]
    # Vendor-side call participants ("Name -- (Company)" strings).
    company_participant_emails: set[str]
    # External (account-side) call participants.
    account_participant_emails: set[str]
    # Attribute names that were promoted to entities/relationships.
    converted_attributes_to_relationships: set[str]
    document_attributes: dict[str, Any] | None


class KGGroundingType(str, Enum):
    """Whether an entity type is backed by concrete source documents."""

    UNGROUNDED = "ungrounded"
    GROUNDED = "grounded"


class KGDefaultEntityDefinition(BaseModel):
    """Template used to seed a default KG entity type."""

    description: str
    grounding: KGGroundingType
    active: bool = False
    grounded_source_name: str | None
    attributes: dict = {}
    entity_values: dict = {}
View File

@@ -0,0 +1,21 @@
from onyx.db.document import update_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.kg.models import KGStage
def reset_extraction_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Clears all extraction staging tables, then moves documents from the
    EXTRACTED stage back to NOT_STARTED so extraction can re-run.
    """
    with get_session_with_current_tenant() as db_session:
        # Relationships are deleted before entities/relationship types —
        # presumably to satisfy foreign-key constraints; confirm before reordering.
        db_session.query(KGRelationshipExtractionStaging).delete()
        db_session.query(KGEntityExtractionStaging).delete()
        db_session.query(KGRelationshipTypeExtractionStaging).delete()
        db_session.commit()
    # Separate session/transaction for the document stage reset.
    with get_session_with_current_tenant() as db_session:
        update_document_kg_stages(db_session, KGStage.EXTRACTED, KGStage.NOT_STARTED)
        db_session.commit()

View File

@@ -0,0 +1,26 @@
from onyx.db.document import reset_all_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipType
from onyx.db.models import KGRelationshipTypeExtractionStaging
def reset_full_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Deletes all normalized KG tables and all extraction staging tables, then
    resets the KG stage of every document.
    """
    with get_session_with_current_tenant() as db_session:
        # Deletion order (relationships before entities) presumably matters for
        # foreign-key constraints; confirm before reordering.
        db_session.query(KGRelationship).delete()
        db_session.query(KGRelationshipType).delete()
        db_session.query(KGEntity).delete()
        db_session.query(KGRelationshipExtractionStaging).delete()
        db_session.query(KGEntityExtractionStaging).delete()
        db_session.query(KGRelationshipTypeExtractionStaging).delete()
        db_session.commit()
    # Separate session/transaction for the document stage reset.
    with get_session_with_current_tenant() as db_session:
        reset_all_document_kg_stages(db_session)
        db_session.commit()

View File

@@ -0,0 +1,22 @@
from onyx.db.document import update_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipType
from onyx.kg.models import KGStage
def reset_normalization_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Deletes the normalized KG tables (staging tables are kept), then moves
    documents from the NORMALIZED stage back to EXTRACTED so normalization
    can re-run.
    """
    with get_session_with_current_tenant() as db_session:
        db_session.query(KGRelationship).delete()
        db_session.query(KGEntity).delete()
        db_session.query(KGRelationshipType).delete()
        db_session.commit()
    # Separate session/transaction for the document stage reset.
    with get_session_with_current_tenant() as db_session:
        update_document_kg_stages(db_session, KGStage.NORMALIZED, KGStage.EXTRACTED)
        db_session.commit()

View File

@@ -0,0 +1,92 @@
from sqlalchemy import or_
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGEntityType
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipType
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.db.models import KGStage
from onyx.kg.resets.reset_vespa import reset_vespa_kg_index
def reset_source_kg_index(source_name: str, tenant_id: str, index_name: str) -> None:
    """
    Resets the knowledge graph index and vespa for a source.

    Clears the KG fields in Vespa, deletes all KG rows (normalized and staging)
    tied to the source's entity types, and resets the KG stage of the source's
    documents.

    Raises:
        ValueError: if no entity types exist for the given source.
    """
    # reset vespa for the source
    reset_vespa_kg_index(tenant_id, index_name, source_name)
    with get_session_with_current_tenant() as db_session:
        # get all the entity types for the given source
        entity_types = [
            et.id_name
            for et in db_session.query(KGEntityType)
            .filter(KGEntityType.grounded_source_name == source_name)
            .all()
        ]
        if not entity_types:
            raise ValueError(f"There are no entity types for the source {source_name}")
        # delete the entity type from the knowledge graph
        # (relationships first, presumably for FK constraints — confirm)
        for entity_type in entity_types:
            db_session.query(KGRelationship).filter(
                or_(
                    KGRelationship.source_node_type == entity_type,
                    KGRelationship.target_node_type == entity_type,
                )
            ).delete()
            db_session.query(KGRelationshipType).filter(
                or_(
                    KGRelationshipType.source_entity_type_id_name == entity_type,
                    KGRelationshipType.target_entity_type_id_name == entity_type,
                )
            ).delete()
            db_session.query(KGEntity).filter(
                KGEntity.entity_type_id_name == entity_type
            ).delete()
            db_session.query(KGRelationshipExtractionStaging).filter(
                or_(
                    KGRelationshipExtractionStaging.source_node_type == entity_type,
                    KGRelationshipExtractionStaging.target_node_type == entity_type,
                )
            ).delete()
            db_session.query(KGEntityExtractionStaging).filter(
                KGEntityExtractionStaging.entity_type_id_name == entity_type
            ).delete()
            db_session.query(KGRelationshipTypeExtractionStaging).filter(
                or_(
                    KGRelationshipTypeExtractionStaging.source_entity_type_id_name
                    == entity_type,
                    KGRelationshipTypeExtractionStaging.target_entity_type_id_name
                    == entity_type,
                )
            ).delete()
        db_session.commit()
    with get_session_with_current_tenant() as db_session:
        # get all the documents for the given source
        kg_connectors = [
            connector.id
            for connector in db_session.query(Connector)
            .filter(Connector.source == DocumentSource(source_name))
            .all()
        ]
        # NOTE(review): this collects DocumentByConnectorCredentialPair.id —
        # verify this is the Document primary key expected by Document.id.in_
        # below (vs. a .document_id column).
        document_ids = [
            cc_pair.id
            for cc_pair in db_session.query(DocumentByConnectorCredentialPair)
            .filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors))
            .all()
        ]
        # reset the kg stage for the documents
        db_session.query(Document).filter(Document.id.in_(document_ids)).update(
            {"kg_stage": KGStage.NOT_STARTED}
        )
        db_session.commit()

View File

@@ -0,0 +1,122 @@
from typing import Any
from retry import retry
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Connector
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import KGEntityType
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGVespaChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@retry(tries=3, delay=1, backoff=2)
def _reset_vespa_for_doc(document_id: str, tenant_id: str, index_name: str) -> None:
    """
    Clear the KG fields (kg_entities / kg_relationships / kg_terms) on every
    Vespa chunk of one document. Retries up to 3 times with exponential backoff.
    """
    vespa_index = VespaIndex(
        index_name=index_name,
        secondary_index_name=None,
        large_chunks_enabled=False,
        secondary_large_chunks_enabled=False,
        multitenant=MULTI_TENANT,
        httpx_client=None,
    )
    # Assigning empty maps wipes the KG annotations from the chunk fields.
    reset_update_dict: dict[str, Any] = {
        "fields": {
            "kg_entities": {"assign": {}},
            "kg_relationships": {"assign": {}},
            "kg_terms": {"assign": {}},
        }
    }
    # Fetch only chunk ids; no other fields are needed to address the updates.
    chunks = _get_chunks_via_visit_api(
        VespaChunkRequest(document_id=document_id),
        index_name,
        IndexFilters(access_control_list=None),
        ["chunk_id"],
        False,
    )
    vespa_requests: list[KGVespaChunkUpdateRequest] = []
    for chunk in chunks:
        # Rebuild the deterministic chunk UUID Vespa uses as the document key.
        doc_chunk_id = get_uuid_from_chunk_info(
            document_id=document_id,
            chunk_id=chunk["fields"]["chunk_id"],
            tenant_id=tenant_id,
            large_chunk_id=None,
        )
        vespa_requests.append(
            KGVespaChunkUpdateRequest(
                document_id=document_id,
                chunk_id=chunk["fields"]["chunk_id"],
                url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=vespa_index.index_name)}/{doc_chunk_id}",
                update_request=reset_update_dict,
            )
        )
    with vespa_index.httpx_client_context as httpx_client:
        vespa_index._apply_kg_chunk_updates_batched(vespa_requests, httpx_client)
def reset_vespa_kg_index(
    tenant_id: str, index_name: str, source_name: str | None = None
) -> None:
    """
    Reset the kg info in vespa for all documents of a given source name,
    or all documents from kg grounded sources if source_name is None.
    """
    logger.info(
        f"Resetting kg vespa index {index_name} for tenant {tenant_id}, "
        f"source: {source_name if source_name else 'all'}"
    )
    # Get all documents that need a vespa reset
    with get_session_with_current_tenant() as db_session:
        if source_name:
            # get all connectors of the given source name
            kg_connectors = [
                connector.id
                for connector in db_session.query(Connector)
                .filter(Connector.source == DocumentSource(source_name))
                .all()
            ]
        else:
            # get all connectors that have kg enabled (active grounded entity types)
            kg_sources = [
                DocumentSource(et.grounded_source_name)
                for et in db_session.query(KGEntityType)
                .filter(
                    KGEntityType.grounded_source_name.is_not(None),
                    KGEntityType.active.is_(True),
                )
                .distinct()
                .all()
            ]
            kg_connectors = [
                connector.id
                for connector in db_session.query(Connector)
                .filter(Connector.source.in_(kg_sources))
                .all()
            ]
        # Get all the documents for the given connectors
        # NOTE(review): this reads DocumentByConnectorCredentialPair.id — verify
        # it is the document id expected by _reset_vespa_for_doc (vs. .document_id).
        document_ids = [
            cc_pair.id
            for cc_pair in db_session.query(DocumentByConnectorCredentialPair)
            .filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors))
            .all()
        ]
    # Reset the kg fields
    for document_id in document_ids:
        _reset_vespa_for_doc(document_id, tenant_id, index_name)

View File

@@ -0,0 +1,23 @@
from typing import List
import numpy as np
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
def encode_string_batch(strings: list[str]) -> np.ndarray:
    """
    Embed a batch of strings with the tenant's current embedding model.

    Args:
        strings: texts to embed; they are embedded as queries.

    Returns:
        2-D numpy array with one embedding vector per input string.
    """
    # Builtin generic `list[str]` replaces the dated typing.List for consistency
    # with the rest of the KG codebase.
    with get_session_with_current_tenant() as db_session:
        current_search_settings = get_current_search_settings(db_session)
        model = EmbeddingModel.from_db_model(
            search_settings=current_search_settings,
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )
        # Get embeddings while session is still open
        embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
    return np.array(embedding)

View File

@@ -0,0 +1,443 @@
import re
from collections import defaultdict
from typing import Dict
from onyx.configs.constants import OnyxCallTypes
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_kg_entity_by_document
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import KGConfigSettings
from onyx.db.models import Document
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGDocumentEntitiesRelationshipsAttributes
from onyx.kg.models import KGEnhancedDocumentMetadata
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT
def _update_implied_entities_relationships(
    kg_core_document_id_name: str,
    owner_list: list[str],
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
    converted_relationships_to_attributes: dict[str, list[str]],
) -> tuple[set[str], set[str], set[str], set[str], dict[str, list[str]]]:
    """
    Fold each owner into the document's implied entities/relationships.

    Owners that look like email addresses are processed as people via
    kg_process_person (which may add entities and a `relationship_type`
    relationship to the document); any other owner string is recorded as an
    attribute value under `relationship_type` instead.

    Returns the updated (implied_entities, implied_relationships,
    company_participant_emails, account_participant_emails,
    converted_relationships_to_attributes).
    """
    for owner in owner_list or []:
        if not is_email(owner):
            # Non-email owners become entity attributes rather than relationships.
            converted_relationships_to_attributes[relationship_type].append(owner)
            continue
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
        ) = kg_process_person(
            owner,
            kg_core_document_id_name,
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            relationship_type,
            kg_config_settings,
        )
    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
        converted_relationships_to_attributes,
    )
def kg_document_entities_relationships_attribute_generation(
    document: Document,
    doc_metadata: KGEnhancedDocumentMetadata,
    active_entities: list[str],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentEntitiesRelationshipsAttributes:
    """
    Generate entities, relationships, and attributes for a document.

    Builds the document's core KG entity id, derives implied entities and
    relationships from its owners (treated differently for call transcripts),
    and converts document attributes that match active entity types into
    entities/relationships. Remaining attributes are returned cleaned of
    converted values and raw id fields.
    """
    # Get document entity type from the KGEnhancedDocumentMetadata
    document_entity_type = doc_metadata.entity_type
    assert document_entity_type is not None
    # Get additional document attributes from the KGEnhancedDocumentMetadata
    document_attributes = doc_metadata.document_attributes
    implied_entities: set[str] = set()
    implied_relationships: set[str] = (
        set()
    )  # 'Relationships' that will be captured as KG relationships
    converted_relationships_to_attributes: dict[str, list[str]] = defaultdict(
        list
    )  # 'Relationships' that will be captured as KG entity attributes
    converted_attributes_to_relationships: set[str] = (
        set()
    )  # Attributes that should be captured as entities and then relationships (Account = ...)
    company_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - participants from vendor
    account_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - external participants
    # Chunk treatment variables: call transcripts get symmetric owner handling.
    document_is_from_call = document_entity_type.lower() in [
        call_type.value.lower() for call_type in OnyxCallTypes
    ]
    # Get core entity: reuse the persisted KG entity id if one already exists.
    document_id = document.id
    primary_owners = document.primary_owners
    secondary_owners = document.secondary_owners
    with get_session_with_current_tenant() as db_session:
        kg_core_document = get_kg_entity_by_document(db_session, document_id)
    if kg_core_document:
        kg_core_document_id_name = kg_core_document.id_name
    else:
        kg_core_document_id_name = make_entity_id(document_entity_type, document_id)
    # Get implied entities and relationships from primary/secondary owners.
    if document_is_from_call:
        # Calls: all owners are "participants".
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=(primary_owners or []) + (secondary_owners or []),
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
    else:
        # Non-calls: primary owners "lead", secondary owners "participate".
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=primary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="leads",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=secondary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
    if document_attributes is not None:
        # Iterate the original dict and pop converted values from a copy so the
        # iteration itself is never mutated.
        cleaned_document_attributes = document_attributes.copy()
        for attribute, value in document_attributes.items():
            if attribute.lower() in [x.lower() for x in active_entities]:
                converted_attributes_to_relationships.add(attribute)
                if isinstance(value, str):
                    # Add relationships for the concrete value, the wildcard
                    # entity, and the wildcard document type. The sets dedupe
                    # repeated adds of the same id.
                    implied_entity = make_entity_id(attribute, value)
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            kg_core_document_id_name,
                        )
                    )
                    implied_entity = make_entity_id(attribute, "*")
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            kg_core_document_id_name,
                        )
                    )
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            make_entity_id(document_entity_type, "*"),
                        )
                    )
                    implied_entity = make_entity_id(attribute, value)
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        make_relationship_id(
                            implied_entity,
                            f"is_{attribute}_of",
                            make_entity_id(document_entity_type, "*"),
                        )
                    )
                    cleaned_document_attributes.pop(attribute)
                elif isinstance(value, list):
                    for item in value:
                        implied_entity = make_entity_id(attribute, item)
                        implied_entities.add(implied_entity)
                        implied_relationships.add(
                            make_relationship_id(
                                implied_entity,
                                f"is_{attribute}_of",
                                kg_core_document_id_name,
                            )
                        )
                    cleaned_document_attributes.pop(attribute)
            # Drop raw id fields from the surviving attributes.
            # NOTE(review): if an attribute both matches an active entity type
            # (popped above) and ends with "_id"/"Id", this second pop would
            # raise KeyError — confirm intended nesting / consider pop(attr, None).
            if attribute.lower().endswith("_id") or attribute.endswith("Id"):
                cleaned_document_attributes.pop(attribute)
    else:
        cleaned_document_attributes = None
    return KGDocumentEntitiesRelationshipsAttributes(
        kg_core_document_id_name=kg_core_document_id_name,
        implied_entities=implied_entities,
        implied_relationships=implied_relationships,
        company_participant_emails=company_participant_emails,
        account_participant_emails=account_participant_emails,
        converted_relationships_to_attributes=converted_relationships_to_attributes,
        converted_attributes_to_relationships=converted_attributes_to_relationships,
        document_attributes=cleaned_document_attributes,
    )
def _prepare_llm_document_content_call(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definition_string: str,
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Build the LLM classification prompt for a call-type document.
    """
    prompt_text = CALL_DOCUMENT_CLASSIFICATION_PROMPT.format(
        beginning_of_call_content=document_classification_content.classification_content,
        category_list=category_list,
        category_options=category_definition_string,
        vendor=kg_config_settings.KG_VENDOR,
    )
    return KGDocumentClassificationPrompt(llm_prompt=prompt_text)
def kg_process_person(
    person: str,
    core_document_id_name: str,
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
) -> tuple[set[str], set[str], set[str], set[str]]:
    """
    Process a single owner and return updated sets with entities and relationships.
    Returns:
        tuple containing (implied_entities, implied_relationships, company_participant_emails, account_participant_emails)
    """
    # NOTE(review): the kg_config_settings argument is immediately replaced by a
    # fresh read from the database, so the passed-in value is effectively
    # ignored — confirm whether the parameter should be honored instead.
    with get_session_with_current_tenant() as db_session:
        kg_config_settings = get_kg_config_settings(db_session)
    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)
    kg_person = kg_email_processing(person, kg_config_settings)
    # Skip people whose company matches an ignored email domain.
    if any(
        domain.lower() in kg_person.company.lower()
        for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
    ):
        return (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
        )
    # Wildcard form of the core document entity. Hoisted above both branches:
    # previously it was only assigned inside the first inner `if` of each
    # branch, so the second inner `if` of the employee branch could hit a
    # NameError when the employee name was already known but the company was not.
    target_general = list(generalize_entities([core_document_id_name]))[0]
    if kg_person.employee:
        company_participant_emails = company_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        # NOTE(review): implied_entities holds "TYPE::name" ids, so comparing a
        # raw name/company against it always succeeds — confirm intended check.
        if kg_person.name not in implied_entities:
            employee_entity = make_entity_id("EMPLOYEE", kg_person.name)
            employee_general = make_entity_id("EMPLOYEE", "*")
            implied_entities.add(employee_entity)
            implied_relationships |= {
                make_relationship_id(
                    employee_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    employee_entity, relationship_type, target_general
                ),
                make_relationship_id(
                    employee_general, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    employee_general, relationship_type, target_general
                ),
            }
        if kg_person.company not in implied_entities:
            company_entity = make_entity_id("VENDOR", kg_person.company)
            implied_entities.add(company_entity)
            implied_relationships |= {
                make_relationship_id(
                    company_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    company_entity, relationship_type, target_general
                ),
            }
    else:
        account_participant_emails = account_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        if kg_person.company not in implied_entities:
            account_entity = make_entity_id("ACCOUNT", kg_person.company)
            account_general = make_entity_id("ACCOUNT", "*")
            implied_entities |= {account_entity, account_general}
            implied_relationships |= {
                make_relationship_id(
                    account_entity, relationship_type, core_document_id_name
                ),
                make_relationship_id(
                    account_general, relationship_type, core_document_id_name
                ),
                make_relationship_id(account_entity, relationship_type, target_general),
                make_relationship_id(
                    account_general, relationship_type, target_general
                ),
            }
    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
    )
def prepare_llm_content_extraction(
    chunk: KGChunkFormat,
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Fill the chunk-preprocessing prompt, choosing the call-specific template
    for call-type sources and the general template otherwise.
    """
    call_source_types = {call_type.value.lower() for call_type in OnyxCallTypes}
    if chunk.source_type.lower() in call_source_types:
        return CALL_CHUNK_PREPROCESSING_PROMPT.format(
            participant_string=company_participant_emails,
            account_participant_string=account_participant_emails,
            vendor=kg_config_settings.KG_VENDOR,
            content=chunk.content,
        )
    return GENERAL_CHUNK_PREPROCESSING_PROMPT.format(
        content=chunk.content,
        vendor=kg_config_settings.KG_VENDOR,
    )
def prepare_llm_document_content(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definitions: dict[str, Dict[str, str | bool]],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Prepare the content for the extraction classification.
    """
    category_definition_string = "".join(
        f"{category}: {category_data['description']}\n"
        for category, category_data in category_definitions.items()
    )
    call_source_types = [call_type.value.lower() for call_type in OnyxCallTypes]
    if document_classification_content.source_type.lower() not in call_source_types:
        # Only call-type documents currently get a classification prompt.
        return KGDocumentClassificationPrompt(llm_prompt=None)
    return _prepare_llm_document_content_call(
        document_classification_content,
        category_list,
        category_definition_string,
        kg_config_settings,
    )
def is_email(email: str) -> bool:
    """
    Check if a string is a valid email address.

    The whole string must match <local>@<domain>.<tld>. Previously only a
    prefix had to match (re.match is not end-anchored), so strings with
    trailing garbage such as "a@b.com extra" were accepted; whitespace was
    also allowed inside the local part and domain. Both are now rejected.
    """
    return re.fullmatch(r"[^@\s]+@[^@\s]+\.[^@\s]+", email) is not None

View File

@@ -0,0 +1,182 @@
from collections import defaultdict
from onyx.db.kg_config import KGConfigSettings
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGPerson
def format_entity_id(entity_id_name: str) -> str:
    """Normalize an entity id by re-splitting and re-joining its parts."""
    parts = split_entity_id(entity_id_name)
    return make_entity_id(*parts)
def make_entity_id(entity_type: str, entity_name: str) -> str:
    """Build a canonical entity id: upper-cased type, lower-cased name."""
    return "::".join((entity_type.upper(), entity_name.lower()))
def split_entity_id(entity_id_name: str) -> list[str]:
    """Split a canonical entity id into its "::"-separated parts."""
    separator = "::"
    return entity_id_name.split(separator)
def get_entity_type(entity_id_name: str) -> str:
    """Return the (upper-cased) type component of an entity id."""
    entity_type = entity_id_name.partition("::")[0]
    return entity_type.upper()
def format_entity_id_for_models(entity_id_name: str) -> str:
    """
    Format an entity id for model-facing output: upper-cased type, quote-free
    title-cased name. Bare names (no "::") are title-cased with no separator.

    Raises:
        ValueError: if the id contains more than one "::" separator.
    """
    parts = entity_id_name.split("::")
    if len(parts) > 2:
        raise ValueError(f"Entity {entity_id_name} is not in the correct format")
    if len(parts) == 2:
        raw_type, raw_name = parts
        separator = "::"
    else:
        raw_type = separator = ""
        raw_name = entity_id_name
    cleaned_name = raw_name.strip().replace('"', "").replace("'", "").title()
    return f"{raw_type.strip().upper()}{separator}{cleaned_name}"
def format_relationship_id(relationship_id_name: str) -> str:
    """Normalize a relationship id by re-splitting and re-joining its parts."""
    parts = split_relationship_id(relationship_id_name)
    return make_relationship_id(*parts)
def make_relationship_id(
    source_node: str, relationship_type: str, target_node: str
) -> str:
    """Build a canonical relationship id: <entity>__<type>__<entity>."""
    return "__".join(
        (
            format_entity_id(source_node),
            relationship_type.lower(),
            format_entity_id(target_node),
        )
    )
def split_relationship_id(relationship_id_name: str) -> list[str]:
    """Split a canonical relationship id into its "__"-separated parts."""
    separator = "__"
    return relationship_id_name.split(separator)
def format_relationship_type_id(relationship_type_id_name: str) -> str:
    """Normalize a relationship-type id by re-splitting and re-joining it."""
    parts = split_relationship_type_id(relationship_type_id_name)
    return make_relationship_type_id(*parts)
def make_relationship_type_id(
    source_node_type: str, relationship_type: str, target_node_type: str
) -> str:
    """Build a canonical relationship-type id: <TYPE>__<type>__<TYPE>."""
    return "__".join(
        (
            source_node_type.upper(),
            relationship_type.lower(),
            target_node_type.upper(),
        )
    )
def split_relationship_type_id(relationship_type_id_name: str) -> list[str]:
    """Split a canonical relationship-type id into its "__"-separated parts."""
    separator = "__"
    return relationship_type_id_name.split(separator)
def extract_relationship_type_id(relationship_id_name: str) -> str:
    """Derive the relationship-type id from a concrete relationship id."""
    source_node, relationship_type, target_node = split_relationship_id(
        relationship_id_name
    )
    source_type = get_entity_type(source_node)
    target_type = get_entity_type(target_node)
    return make_relationship_type_id(source_type, relationship_type, target_type)
def aggregate_kg_extractions(
    connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
) -> KGAggregatedExtractions:
    """
    Merge per-connector KG extraction results into a single aggregate.

    Entity, relationship, and term counts are summed; for the grounded-entity ->
    document-id mapping the last-seen value wins; per-entity attribute dicts are
    merged key-by-key (later connectors overwrite earlier ones on key clashes).
    """
    aggregated = KGAggregatedExtractions(
        grounded_entities_document_ids=defaultdict(str),
        entities=defaultdict(int),
        relationships=defaultdict(lambda: defaultdict(int)),
        terms=defaultdict(int),
        attributes=defaultdict(dict),
    )
    for extraction in connector_aggregated_kg_extractions_list:
        # Last writer wins for grounded entity -> document id.
        for (
            grounded_entity,
            document_id,
        ) in extraction.grounded_entities_document_ids.items():
            aggregated.grounded_entities_document_ids[grounded_entity] = document_id

        # defaultdict(int) makes the explicit membership checks unnecessary.
        for entity, count in extraction.entities.items():
            aggregated.entities[entity] += count

        for relationship, relationship_data in extraction.relationships.items():
            for source_document_id, count in relationship_data.items():
                aggregated.relationships[relationship][source_document_id] += count

        for term, count in extraction.terms.items():
            aggregated.terms[term] += count

        # Previously the input attributes were silently dropped, leaving the
        # aggregate's attributes permanently empty; merge them per entity.
        for entity, attrs in extraction.attributes.items():
            aggregated.attributes[entity].update(attrs)

    return aggregated
def kg_email_processing(email: str, kg_config_settings: KGConfigSettings) -> KGPerson:
    """
    Turn an email address into a KGPerson, classifying the owner as a vendor
    employee or an external account based on the configured vendor domains.
    """
    name, company_domain = email.split("@")
    assert isinstance(company_domain, str)
    assert isinstance(kg_config_settings.KG_VENDOR_DOMAINS, list)
    assert isinstance(kg_config_settings.KG_VENDOR, str)
    is_employee = any(
        vendor_domain in company_domain
        for vendor_domain in kg_config_settings.KG_VENDOR_DOMAINS
    )
    company = (
        kg_config_settings.KG_VENDOR if is_employee else company_domain.capitalize()
    )
    return KGPerson(name=name, company=company, employee=is_employee)
def generalize_entities(entities: list[str]) -> set[str]:
    """
    Generalize entities to their wildcard superclass,
    e.g. "EMPLOYEE::bob" -> "EMPLOYEE::*".
    """
    # Inlined make_entity_id(get_entity_type(...), "*"); "*" is unchanged by lower().
    return {f"{entity.split('::', 1)[0].upper()}::*" for entity in entities}
def generalize_relationships(relationships: list[str]) -> set[str]:
    """
    Generalize relationships to their wildcard superclass forms: for each
    relationship, emit the source-wildcarded, target-wildcarded, and
    both-wildcarded variants.
    """
    generalized: set[str] = set()
    for relationship in relationships:
        parts = split_relationship_id(relationship)
        assert len(parts) == 3, "Relationship is not in the correct format"
        source_entity, relationship_type, target_entity = parts
        source_general = make_entity_id(get_entity_type(source_entity), "*")
        target_general = make_entity_id(get_entity_type(target_entity), "*")
        generalized.add(
            make_relationship_id(source_general, relationship_type, target_entity)
        )
        generalized.add(
            make_relationship_id(source_entity, relationship_type, target_general)
        )
        generalized.add(
            make_relationship_id(source_general, relationship_type, target_general)
        )
    return generalized

View File

@@ -0,0 +1,253 @@
import json
from collections.abc import Generator
from onyx.configs.constants import OnyxCallTypes
from onyx.db.kg_config import KGConfigSettings
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_document_classification_content_for_kg_processing(
    document_ids: list[str],
    source: str,
    index_name: str,
    kg_config_settings: KGConfigSettings,
    batch_size: int = 8,
    num_classification_chunks: int = 3,
    entity_type: str | None = None,
) -> Generator[list[KGClassificationContent], None, None]:
    """
    Generates the content used for initial classification of a document from
    the first num_classification_chunks chunks.

    Yields lists of KGClassificationContent, one entry per document that has
    at least one chunk in Vespa; documents without chunks are skipped silently.
    NOTE(review): `source` is not used inside this function — confirm whether
    callers still need the parameter.
    """
    classification_content_list: list[KGClassificationContent] = []
    # Process documents in batches so each yield carries up to batch_size items.
    for i in range(0, len(document_ids), batch_size):
        batch_document_ids = document_ids[i : i + batch_size]
        for document_id in batch_document_ids:
            # Fetch only the leading chunks of the document via the visit API.
            first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
                chunk_request=VespaChunkRequest(
                    document_id=document_id,
                    max_chunk_ind=num_classification_chunks - 1,
                    min_chunk_ind=0,
                ),
                index_name=index_name,
                filters=IndexFilters(access_control_list=None),
                field_names=[
                    "document_id",
                    "chunk_id",
                    "title",
                    "content",
                    "metadata",
                    "source_type",
                    "primary_owners",
                    "secondary_owners",
                ],
                get_large_chunks=False,
            )
            if len(first_num_classification_chunks) == 0:
                continue
            # Sort by chunk_id (visit order is not guaranteed) and trim to the limit.
            first_num_classification_chunks = sorted(
                first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
            )[:num_classification_chunks]
            classification_content = _get_classification_content_from_chunks(
                first_num_classification_chunks,
                kg_config_settings,
            )
            # Metadata may arrive JSON-encoded; normalize to a dict (or None).
            metadata = first_num_classification_chunks[0]["fields"]["metadata"]
            if isinstance(metadata, str):
                metadata = json.loads(metadata)
            assert isinstance(metadata, dict) or metadata is None
            classification_content_list.append(
                KGClassificationContent(
                    document_id=document_id,
                    classification_content=classification_content,
                    source_type=first_num_classification_chunks[0]["fields"][
                        "source_type"
                    ],
                    source_metadata=metadata,
                    entity_type=entity_type,
                )
            )
        # Yield the batch of classification content
        if classification_content_list:
            yield classification_content_list
            classification_content_list = []
    # Yield any remaining items (defensive; the list is reset after every batch)
    if classification_content_list:
        yield classification_content_list
def get_document_chunks_for_kg_processing(
    document_id: str,
    deep_extraction: bool,
    index_name: str,
    tenant_id: str,
    batch_size: int = 8,
) -> Generator[list[KGChunkFormat], None, None]:
    """
    Retrieves chunks from Vespa for the given document IDs and converts them to KGChunks.
    Args:
        document_id (str): ID of the document to fetch chunks for
        deep_extraction (bool): Whether to perform deep extraction
        index_name (str): Name of the Vespa index
        tenant_id (str): ID of the tenant
        batch_size (int): Number of chunks to fetch per batch
    Yields:
        list[KGChunk]: Batches of chunks ready for KG processing
    """
    current_batch: list[KGChunkFormat] = []
    # get all chunks for the document
    chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=IndexFilters(access_control_list=None, tenant_id=tenant_id),
        field_names=[
            "document_id",
            "chunk_id",
            "title",
            "content",
            "metadata",
            "primary_owners",
            "secondary_owners",
            "source_type",
            "kg_entities",
            "kg_relationships",
            "kg_terms",
        ],
        get_large_chunks=False,
    )
    # Convert Vespa chunks to KGChunkFormat batches.
    # (The index from the previous enumerate() was unused; plain iteration suffices.)
    for chunk in chunks:
        fields = chunk["fields"]
        # Vespa may return metadata as a JSON-encoded string; normalize to a dict.
        if isinstance(fields.get("metadata", {}), str):
            fields["metadata"] = json.loads(fields["metadata"])
        current_batch.append(
            KGChunkFormat(
                connector_id=None,  # We may need to adjust this
                document_id=fields.get("document_id"),
                chunk_id=fields.get("chunk_id"),
                primary_owners=fields.get("primary_owners", []),
                secondary_owners=fields.get("secondary_owners", []),
                source_type=fields.get("source_type", ""),
                title=fields.get("title", ""),
                content=fields.get("content", ""),
                metadata=fields.get("metadata", {}),
                entities=fields.get("kg_entities", {}),
                relationships=fields.get("kg_relationships", {}),
                terms=fields.get("kg_terms", {}),
                deep_extraction=deep_extraction,
            )
        )
        if len(current_batch) >= batch_size:
            yield current_batch
            current_batch = []
    # Yield any remaining chunks
    if current_batch:
        yield current_batch
def _get_classification_content_from_call_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Creates a KGClassificationContent object from a list of call chunks.

    Builds a text block containing the call title, the vendor-side and
    account-side participants (derived from chunk owners), and the
    concatenated content of the given chunks.
    """
    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)
    # Owners are read from the first chunk only.
    primary_owners = first_num_classification_chunks[0]["fields"].get(
        "primary_owners", []
    )
    secondary_owners = first_num_classification_chunks[0]["fields"].get(
        "secondary_owners", []
    )
    company_participant_emails = set()
    account_participant_emails = set()
    for owner in primary_owners + secondary_owners:
        kg_owner = kg_email_processing(owner, kg_config_settings)
        # Skip participants whose company matches an ignored email domain.
        if any(
            domain.lower() in kg_owner.company.lower()
            for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
        ):
            continue
        # Vendor employees and external account contacts go into separate lists.
        if kg_owner.employee:
            company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
        else:
            account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
    participant_string = "\n - " + "\n - ".join(company_participant_emails)
    account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
    title_string = first_num_classification_chunks[0]["fields"]["title"]
    content_string = "\n".join(
        [
            chunk_content["fields"]["content"]
            for chunk_content in first_num_classification_chunks
        ]
    )
    classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
    return classification_content
def _get_classification_content_from_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Creates the classification content string from a list of chunks,
    delegating to the call-specific formatter for call-type sources.
    """
    lead_fields = first_num_classification_chunks[0]["fields"]
    call_source_types = {call_type.value.lower() for call_type in OnyxCallTypes}
    if lead_fields["source_type"].lower() in call_source_types:
        return _get_classification_content_from_call_chunks(
            first_num_classification_chunks,
            kg_config_settings,
        )
    joined_contents = "\n".join(
        [chunk["fields"]["content"] for chunk in first_num_classification_chunks]
    )
    return lead_fields["title"] + "\n" + joined_contents

View File

@@ -41,6 +41,8 @@ from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE
from onyx.configs.app_configs import SYSTEM_RECURSION_LIMIT
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
@@ -223,6 +225,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
)
SqlEngine.get_engine()
SqlEngine.init_readonly_engine(
pool_size=POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)

File diff suppressed because it is too large Load Diff

View File

@@ -78,6 +78,11 @@ class SearchToolOverrideKwargs(BaseModel):
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
expanded_queries: QueryExpansions | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
kg_sources: list[str] | None = None
kg_chunk_id_zero_only: bool | None = False
class Config:
arbitrary_types_allowed = True

View File

@@ -296,6 +296,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_sources = None
time_cutoff = None
expanded_queries = None
kg_entities = None
kg_relationships = None
kg_terms = None
kg_sources = None
kg_chunk_id_zero_only = False
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
@@ -308,6 +313,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
expanded_queries = override_kwargs.expanded_queries
kg_entities = override_kwargs.kg_entities
kg_relationships = override_kwargs.kg_relationships
kg_terms = override_kwargs.kg_terms
kg_sources = override_kwargs.kg_sources
kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -331,6 +341,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Overwrite time-cutoff should supercede existing time-cutoff, even if defined
retrieval_options.filters.time_cutoff = time_cutoff
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
if kg_entities:
retrieval_options.filters.kg_entities = kg_entities
if kg_relationships:
retrieval_options.filters.kg_relationships = kg_relationships
if kg_terms:
retrieval_options.filters.kg_terms = kg_terms
if kg_sources:
retrieval_options.filters.kg_sources = kg_sources
if kg_chunk_id_zero_only:
retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,

View File

@@ -80,6 +80,7 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.46.1
supervisor==4.2.5
RapidFuzz==3.13.0
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0

View File

@@ -101,6 +101,20 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_kg_processing]
command=celery -A onyx.background.celery.versioned_apps.kg_processing worker
--loglevel=INFO
--hostname=kg_processing@%%n
-Q kg_processing
stdout_logfile=/var/log/celery_worker_kg_processing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A onyx.background.celery.versioned_apps.beat beat

View File

@@ -155,7 +155,7 @@
"Email: fiannellib46@marriott.com\nIsDeleted: false\nLastName: Iannelli\nIsEmailBounced: false\nFirstName: Felicio\nIsPriorityRecord: false\nCleanStatus: Pending"
],
"semantic_identifier": "Voonder",
"metadata": {},
"metadata": {"object_type": "Account"},
"primary_owners": {"email": "hagen@danswer.ai", "first_name": "Hagen", "last_name": "oneill"},
"secondary_owners": null,
"title": null

View File

@@ -50,15 +50,21 @@ def answer_instance(
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config, mocker)
def _answer_fixture_impl(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
rerank_settings: RerankingDetails | None = None,
) -> Answer:
mock_db_session = Mock(spec=Session)
mock_query = Mock()
mock_query.all.return_value = []
mock_db_session.query.return_value = mock_query
return Answer(
prompt_builder=AnswerPromptBuilder(
user_message=default_build_user_message(
@@ -73,7 +79,7 @@ def _answer_fixture_impl(
raw_user_query=QUERY,
raw_user_uploaded_files=[],
),
db_session=Mock(spec=Session),
db_session=mock_db_session,
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,
@@ -404,7 +410,11 @@ def test_no_slow_reranking(
)
)
answer_instance = _answer_fixture_impl(
mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
mock_llm,
answer_style_config,
prompt_config,
mocker,
rerank_settings=rerank_settings,
)
assert answer_instance.graph_config.inputs.rerank_settings == rerank_settings

View File

@@ -42,6 +42,7 @@ def test_skip_gen_ai_answer_generation_flag(
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
@@ -58,8 +59,14 @@ def test_skip_gen_ai_answer_generation_flag(
mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()]
# Set up the mock database session
mock_db_session = Mock(spec=Session)
mock_query = Mock()
mock_db_session.query.return_value = mock_query
mock_query.all.return_value = [] # Return empty list for KGConfig query
answer = Answer(
db_session=Mock(spec=Session),
db_session=mock_db_session,
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,

View File

@@ -747,11 +747,17 @@ def test_salesforce_sqlite() -> None:
sf_db.apply_schema()
_create_csv_with_example_data(sf_db)
_test_query(sf_db)
_test_upsert(sf_db)
_test_relationships(sf_db)
_test_account_with_children(sf_db)
_test_relationship_updates(sf_db)
_test_get_affected_parent_ids(sf_db)
sf_db.close()

View File

@@ -183,6 +183,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
@@ -384,6 +386,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -148,6 +148,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- VESPA_HOST=index
- REDIS_HOST=cache
@@ -330,6 +332,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -166,6 +166,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
@@ -357,6 +359,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -154,6 +154,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432"
volumes:

View File

@@ -55,3 +55,10 @@ SESSION_EXPIRE_TIME_SECONDS=604800
# Default values here are what Postgres uses by default, feel free to change.
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password
# Default values here are for the read-only user used by the knowledge graph and other future read-only purposes.
# Please change this password in production!
DB_READONLY_USER=db_readonly_user
DB_READONLY_PASSWORD=password

View File

@@ -40,6 +40,16 @@ spec:
secretKeyRef:
name: onyx-secrets
key: postgres_password
- name: DB_READONLY_USER
  valueFrom:
    secretKeyRef:
      name: onyx-secrets
      key: db_readonly_user
- name: DB_READONLY_PASSWORD
  valueFrom:
    secretKeyRef:
      name: onyx-secrets
      key: db_readonly_password
args: ["-c", "max_connections=250"]
ports:
- containerPort: 5432

View File

@@ -532,7 +532,7 @@ export function AssistantEditor({
// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0;
const numChunks = searchToolEnabled ? values.num_chunks || 25 : 0;
const starterMessages = values.starter_messages
.filter(
(message: { message: string }) => message.message.trim() !== ""