Compare commits

...

67 Commits

Author SHA1 Message Date
joachim-danswer
92bb851e19 Update a3_generate_simple_sql.py 2025-05-17 10:06:33 -07:00
joachim-danswer
3951cf2f78 SQL View fix 2025-05-17 09:47:56 -07:00
joachim-danswer
9949ffd9bc read-only user replacement & cleanup 2025-05-16 09:03:19 -07:00
joachim-danswer
5574e7485f EL comments 2025-05-16 07:25:25 -07:00
joachim-danswer
99cf8dca74 Salesforce test update 2025-05-15 22:07:42 -07:00
joachim-danswer
aed1d87d89 setting updated 2025-05-15 22:07:42 -07:00
joachim-danswer
6bb9b17c1c nit 2025-05-15 22:07:42 -07:00
joachim-danswer
aaed5716f9 EL + RK (pt 1) comments 2025-05-15 22:07:42 -07:00
joachim-danswer
e302bcafaf small fixes 2025-05-15 22:07:42 -07:00
joachim-danswer
57d412c3a7 mypy fix 2025-05-15 22:07:42 -07:00
joachim-danswer
6fb0933cfb SF Connector fix & kg_stage removal for one table 2025-05-15 22:07:42 -07:00
joachim-danswer
e265b4f747 small adjustments 2025-05-15 22:07:42 -07:00
joachim-danswer
6c49736024 small changes 2025-05-15 22:07:42 -07:00
joachim-danswer
1779b65185 test update 2025-05-15 22:07:42 -07:00
joachim-danswer
148bce59d9 salesforce fix 2025-05-15 22:07:42 -07:00
joachim-danswer
8cadb57df2 test improvements 2025-05-15 22:07:42 -07:00
joachim-danswer
33a4b15ae2 env vars 2025-05-15 22:07:42 -07:00
joachim-danswer
49e1a4b782 nf 2025-05-15 22:07:42 -07:00
joachim-danswer
1f95e2291d read_only pool + misc 2025-05-15 22:07:42 -07:00
joachim-danswer
64cba0bdca test fix 2025-05-15 22:07:42 -07:00
joachim-danswer
daad2e9de0 test fix 2025-05-15 22:07:42 -07:00
joachim-danswer
69d8aae5f0 test fixes 2025-05-15 22:07:42 -07:00
joachim-danswer
34630b1947 mypy fixes 2025-05-15 22:07:42 -07:00
joachim-danswer
e65e192622 post_rebase fixes 2025-05-15 22:07:41 -07:00
joachim-danswer
713cc6b6a4 migration update 2025-05-15 22:07:41 -07:00
joachim-danswer
84de01d726 quick migration fix 2025-05-15 22:07:41 -07:00
joachim-danswer
df2bc953e8 migrations and env vars 2025-05-15 22:07:41 -07:00
joachim-danswer
cd4fd267e1 EL changes pt 2 2025-05-15 22:07:41 -07:00
joachim-danswer
078fe358bb kg-filtered search 2025-05-15 22:07:41 -07:00
joachim-danswer
2b9a8baf7a evan updates + quite a bit more 2025-05-15 22:07:41 -07:00
joachim-danswer
dbed47b4b0 base w/ salesforce 2025-05-15 22:07:41 -07:00
joachim-danswer
ffc0692f5f SF Connector update
- include account information
2025-05-15 22:07:41 -07:00
joachim-danswer
0c5db00673 initial Account/SF Connector chnges 2025-05-15 22:07:41 -07:00
joachim-danswer
b4641189a0 EL initial comments 2025-05-15 22:07:41 -07:00
joachim-danswer
c01136d661 typo 2025-05-15 22:07:41 -07:00
joachim-danswer
034224a946 kg read-only user creation as part of migration 2025-05-15 22:07:41 -07:00
joachim-danswer
a03c2ed4ff fix for missing entity attributes 2025-05-15 22:07:41 -07:00
joachim-danswer
ad966c51b4 read-only user creation as part of setup 2025-05-15 22:07:41 -07:00
joachim-danswer
d68bd41e72 extraction process fix 2025-05-15 22:07:41 -07:00
joachim-danswer
8e2417f563 fixes 2025-05-15 22:07:41 -07:00
joachim-danswer
498fc587ab more feature flag checks 2025-05-15 22:07:41 -07:00
joachim-danswer
470a4b88ae updates 2025-05-15 22:07:41 -07:00
joachim-danswer
0906125af0 fixes 2025-05-15 22:07:41 -07:00
joachim-danswer
38733965ba simplifications & cleanup 2025-05-15 22:07:41 -07:00
joachim-danswer
acf8a57798 updates 2025-05-15 22:07:41 -07:00
joachim-danswer
462592867c nits 2025-05-15 22:07:41 -07:00
joachim-danswer
0980aa4222 fixed KG extraction 2025-05-15 22:07:41 -07:00
joachim-danswer
a8b748066e progress 2025-05-15 22:07:41 -07:00
joachim-danswer
81833c8a54 new extraction 2025-05-15 22:07:41 -07:00
joachim-danswer
cd1b48acd4 base 2025-05-15 22:07:41 -07:00
joachim-danswer
1552c61dbd progress 2025-05-15 22:07:41 -07:00
joachim-danswer
5915066558 a3+ 2025-05-15 22:07:41 -07:00
joachim-danswer
3aa2b51ca4 rebase migration change 2025-05-15 22:07:41 -07:00
joachim-danswer
7bfb1a4a61 migration downgrade fix 2025-05-15 22:07:41 -07:00
joachim-danswer
dfa12e9caf dev 2025-05-15 22:07:41 -07:00
joachim-danswer
c595df7a1e focus on metadata relatonships 1 2025-05-15 22:07:41 -07:00
joachim-danswer
1b08b35262 extraction revamp 2025-05-15 22:07:41 -07:00
joachim-danswer
fe919d54da separate read_only engine 2025-05-15 22:07:41 -07:00
joachim-danswer
8259773174 updates 2025-05-15 22:07:41 -07:00
joachim-danswer
b1406d6b65 nits 2025-05-15 22:07:41 -07:00
joachim-danswer
95f5769fb6 restructuring 2025-05-15 22:07:41 -07:00
joachim-danswer
8ed105bc05 temp view creation 2025-05-15 22:07:41 -07:00
joachim-danswer
b818842a94 relationship table + query update 2025-05-15 22:07:41 -07:00
joachim-danswer
b72ec48b41 more adjustments 2025-05-15 22:07:41 -07:00
joachim-danswer
45862c941b transfer 1 - incomplete 2025-05-15 22:07:41 -07:00
joachim-danswer
b04925816c dc fix 2025-05-15 22:07:41 -07:00
joachim-danswer
c14468413a db setup 2025-05-15 22:07:41 -07:00
83 changed files with 11604 additions and 31 deletions

View File

@@ -114,6 +114,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
@@ -212,6 +214,8 @@ jobs:
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \

View File

@@ -152,6 +152,8 @@ jobs:
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \

View File

@@ -0,0 +1,438 @@
"""create knowledge graph tables
Revision ID: 495cb26ce93e
Revises: a7688ab35c45
Create Date: 2025-03-19 08:51:14.341989
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import text
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
revision = "495cb26ce93e"
down_revision = "a7688ab35c45"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the knowledge-graph schema.

    Creates the kg_* tables (config, entity type, relationship type, entity,
    relationship, term, plus the *_extraction_staging variants), their indexes,
    and adds kg-related columns to the existing ``document`` and ``connector``
    tables. In single-tenant mode it also creates the permission-less
    read-only DB user.

    Raises:
        Exception: if DB_READONLY_USER / DB_READONLY_PASSWORD are unset in
            single-tenant mode.
    """
    # Create a new permission-less user to be later used for knowledge graph queries.
    # The user will later get temporary read privileges for a specific view that will be
    # ad hoc generated specific to a knowledge graph query.
    #
    # Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
    # environment variables MUST be set. Otherwise, an exception will be raised.
    if not MULTI_TENANT:
        # Create read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is created in the alembic_tenants migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
        # NOTE(review): the user name/password are interpolated directly into the
        # SQL text; a value containing a single quote would break the statement.
        # Assumed to come from trusted deployment config — confirm.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- Explicitly revoke all privileges including CONNECT
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
    # Create KGConfig table: simple name -> list-of-values configuration store
    op.create_table(
        "kg_config",
        sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
        sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
        sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
    )
    # Create KGEntityType table
    op.create_table(
        "kg_entity_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column("grounding", sa.String(), nullable=False),
        sa.Column(
            "attributes",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=False),
        sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column("grounded_source_name", sa.String(), nullable=True),
        sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
    )
    # Create KGRelationshipType table
    op.create_table(
        "kg_relationship_type",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )
    # Create KGRelationshipTypeExtractionStaging table
    # (same shape as kg_relationship_type; holds rows mid-extraction)
    op.create_table(
        "kg_relationship_type_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column(
            "source_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column(
            "target_entity_type_id_name", sa.String(), nullable=False, index=True
        ),
        sa.Column("definition", sa.Boolean(), nullable=False, default=False),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("active", sa.Boolean(), nullable=False, default=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.Column(
            "clustering",
            postgresql.JSONB,
            nullable=False,
            server_default="{}",
        ),
        sa.ForeignKeyConstraint(
            ["source_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_entity_type_id_name"], ["kg_entity_type.id_name"]
        ),
    )
    # Create KGEntity table
    op.create_table(
        "kg_entity",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("sub_type", sa.String(), nullable=True, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
    )
    # Composite indexes for ACL-filtered and name lookups
    op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
    op.create_index(
        "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
    )
    # Create KGEntityExtractionStaging table
    # (same shape as kg_entity; holds rows mid-extraction)
    op.create_table(
        "kg_entity_extraction_staging",
        sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column("name", sa.String(), nullable=False, index=True),
        sa.Column("sub_type", sa.String(), nullable=True, index=True),
        sa.Column("document_id", sa.String(), nullable=True, index=True),
        sa.Column(
            "alternative_names",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("description", sa.String(), nullable=True),
        sa.Column(
            "keywords",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column(
            "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
        ),
        sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
        sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
    )
    op.create_index(
        "ix_entity_extraction_staging_acl",
        "kg_entity_extraction_staging",
        ["entity_type_id_name", "acl"],
    )
    op.create_index(
        "ix_entity_extraction_staging_name_search",
        "kg_entity_extraction_staging",
        ["name", "entity_type_id_name"],
    )
    # Create KGRelationship table
    op.create_table(
        "kg_relationship",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"], ["kg_relationship_type.id_name"]
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
        # Composite PK: the same relationship id may appear once per source document
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
    )
    # Create KGRelationshipExtractionStaging table
    # (same shape as kg_relationship, but FKs point at the staging entity tables)
    op.create_table(
        "kg_relationship_extraction_staging",
        sa.Column("id_name", sa.String(), nullable=False, index=True),
        sa.Column("source_node", sa.String(), nullable=False, index=True),
        sa.Column("target_node", sa.String(), nullable=False, index=True),
        sa.Column("source_node_type", sa.String(), nullable=False, index=True),
        sa.Column("target_node_type", sa.String(), nullable=False, index=True),
        sa.Column("source_document", sa.String(), nullable=True, index=True),
        sa.Column("type", sa.String(), nullable=False, index=True),
        sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
        sa.Column("occurrences", sa.Integer(), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
        sa.ForeignKeyConstraint(
            ["source_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(
            ["target_node"], ["kg_entity_extraction_staging.id_name"]
        ),
        sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
        sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
        sa.ForeignKeyConstraint(
            ["relationship_type_id_name"],
            ["kg_relationship_type_extraction_staging.id_name"],
        ),
        sa.UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_extraction_staging_source_target_type",
        ),
        sa.PrimaryKeyConstraint("id_name", "source_document"),
    )
    op.create_index(
        "ix_kg_relationship_extraction_staging_nodes",
        "kg_relationship_extraction_staging",
        ["source_node", "target_node"],
    )
    # Create KGTerm table
    op.create_table(
        "kg_term",
        sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
        sa.Column(
            "entity_types",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
        ),
        sa.Column(
            "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
        ),
    )
    op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
    op.create_index("ix_search_term_term", "kg_term", ["id_term"])
    # Extend pre-existing tables with kg processing state
    op.add_column(
        "document",
        sa.Column("kg_stage", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "document",
        sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "connector",
        sa.Column(
            "kg_processing_enabled",
            sa.Boolean(),
            nullable=True,
            server_default="false",
        ),
    )
def downgrade() -> None:
    """Drop the knowledge-graph schema created by upgrade().

    Tables are dropped in dependency order (referencing tables first); the
    added document/connector columns are removed; in single-tenant mode the
    read-only DB user is dropped as well.
    """
    # Drop tables in reverse order of creation to handle dependencies
    op.drop_table("kg_term")
    op.drop_table("kg_relationship")
    op.drop_table("kg_entity")
    op.drop_table("kg_relationship_type")
    op.drop_table("kg_relationship_extraction_staging")
    op.drop_table("kg_relationship_type_extraction_staging")
    op.drop_table("kg_entity_extraction_staging")
    op.drop_table("kg_entity_type")
    # Remove the kg columns added to pre-existing tables
    op.drop_column("connector", "kg_processing_enabled")
    op.drop_column("document", "kg_stage")
    op.drop_column("document", "kg_processing_time")
    op.drop_table("kg_config")
    if not MULTI_TENANT:
        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is dropped in the alembic_tenants migration.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then revoke all privileges from the public schema
                        EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

View File

@@ -0,0 +1,72 @@
"""add_db_readonly_user
Revision ID: 3b9f09038764
Revises: 3b45e0018bf1
Create Date: 2025-05-11 11:05:11.436977
"""
from sqlalchemy import text
from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
revision = "3b9f09038764"
down_revision = "3b45e0018bf1"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the permission-less read-only DB user (multi-tenant deployments only).

    Raises:
        Exception: if DB_READONLY_USER / DB_READONLY_PASSWORD are unset.
    """
    if MULTI_TENANT:
        # Create read-only db user here only in multi-tenant mode. For single-tenant mode,
        # the user is created in the standard migration.
        if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
            raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
        # NOTE(review): the user name/password are interpolated directly into the
        # SQL text; a value containing a single quote would break the statement.
        # Assumed to come from trusted deployment config — confirm.
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                        -- Explicitly revoke all privileges including CONNECT
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Grant only the CONNECT privilege
                        EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
def downgrade() -> None:
    """Drop the read-only DB user (multi-tenant deployments only)."""
    if MULTI_TENANT:
        # Drop read-only db user here only in multi-tenant mode. For single-tenant mode,
        # the user is dropped in the standard migration.
        # (Comment fixed: it previously said "single tenant mode", contradicting
        # the `if MULTI_TENANT` guard above.)
        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then revoke all privileges from the public schema
                        EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )

View File

@@ -35,9 +35,8 @@ def research_object_source(
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.search_request.persona is None:
@@ -153,7 +152,6 @@ def research_object_source(
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader

View File

@@ -73,7 +73,6 @@ def consolidate_object_research(
)
]
graph_config.tooling.primary_llm
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader

View File

@@ -0,0 +1,68 @@
from collections.abc import Hashable
from datetime import datetime
from enum import Enum
from langgraph.types import Send
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
class KGAnalysisPath(str, Enum):
    """Graph node names that the post-SQL-generation conditional edge can route to."""

    PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
    CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
def simple_vs_search(
    state: MainState,
) -> str:
    """Route to deep-search filter construction when the strategy is DEEP or the
    search type is SEARCH; otherwise route to KG-only answer processing."""
    needs_deep_filters = (
        state.strategy == KGAnswerStrategy.DEEP
        or state.search_type == KGSearchType.SEARCH
    )
    if needs_deep_filters:
        return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
    return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
def research_individual_object(
    state: MainState,
) -> list[Send | Hashable] | str:
    """Fan out one deep-search task per div-con entity (SQL + DEEP), fall through
    to the filtered-search node for SEARCH, or raise on an unsupported combination."""
    edge_start_time = datetime.now()

    # These are invariants established by earlier graph nodes.
    assert state.div_con_entities is not None
    assert state.broken_down_question is not None
    assert state.vespa_filter_results is not None

    is_deep_sql = (
        state.search_type == KGSearchType.SQL
        and state.strategy == KGAnswerStrategy.DEEP
    )
    if is_deep_sql:
        dispatches: list[Send | Hashable] = []
        for research_nr, entity in enumerate(state.div_con_entities, start=1):
            dispatches.append(
                Send(
                    "process_individual_deep_search",
                    ResearchObjectInput(
                        research_nr=research_nr,
                        entity=entity,
                        broken_down_question=state.broken_down_question,
                        vespa_filter_results=state.vespa_filter_results,
                        source_division=state.source_division,
                        source_entity_filters=state.source_filters,
                        log_messages=[
                            f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
                        ],
                        step_results=[],
                    ),
                )
            )
        return dispatches

    if state.search_type == KGSearchType.SEARCH:
        return "filtered_search"

    raise ValueError(
        f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
    )

View File

@@ -0,0 +1,132 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.kb_search.conditional_edges import (
research_individual_object,
)
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
generate_simple_sql,
)
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
construct_deep_search_filters,
)
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
process_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
consolidate_individual_deep_search,
)
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
process_kg_only_answers,
)
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
from onyx.agents.agent_search.kb_search.states import MainInput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kb_graph_builder() -> StateGraph:
    """
    LangGraph graph builder for the knowledge graph search process.
    """
    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    ### Add nodes ###
    # NOTE(review): "consoldidate_individual_deep_search" is misspelled, but the
    # name is internal to this graph and must match the edge keys below.
    node_specs = [
        ("extract_ert", extract_ert),
        ("generate_simple_sql", generate_simple_sql),
        ("filtered_search", filtered_search),
        ("analyze", analyze),
        ("generate_answer", generate_answer),
        ("construct_deep_search_filters", construct_deep_search_filters),
        ("process_individual_deep_search", process_individual_deep_search),
        ("consoldidate_individual_deep_search", consolidate_individual_deep_search),
        ("process_kg_only_answers", process_kg_only_answers),
    ]
    for node_name, node_fn in node_specs:
        graph.add_node(node_name, node_fn)

    ### Add edges ###
    graph.add_edge(start_key=START, end_key="extract_ert")
    graph.add_edge(start_key="extract_ert", end_key="analyze")
    graph.add_edge(start_key="analyze", end_key="generate_simple_sql")
    graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
    graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
    graph.add_conditional_edges(
        source="construct_deep_search_filters",
        path=research_individual_object,
        path_map=["process_individual_deep_search", "filtered_search"],
    )
    graph.add_edge(
        start_key="process_individual_deep_search",
        end_key="consoldidate_individual_deep_search",
    )
    graph.add_edge(
        start_key="consoldidate_individual_deep_search", end_key="generate_answer"
    )
    graph.add_edge(start_key="filtered_search", end_key="generate_answer")
    graph.add_edge(start_key="generate_answer", end_key=END)

    return graph

View File

@@ -0,0 +1,461 @@
import re
from time import sleep
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import LlmDoc
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.db.document import get_kg_doc_info_for_entity_name
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.db.entity_type import get_entity_types
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _check_entities_disconnected(
current_entities: list[str], current_relationships: list[str]
) -> bool:
"""
Check if all entities in current_entities are disconnected via the given relationships.
Relationships are in the format: source_entity__relationship_name__target_entity
Args:
current_entities: List of entity IDs to check connectivity for
current_relationships: List of relationships in format source__relationship__target
Returns:
bool: True if all entities are disconnected, False otherwise
"""
if not current_entities:
return True
# Create a graph representation using adjacency list
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
# Build the graph from relationships
for relationship in current_relationships:
try:
source, _, target = relationship.split("__")
if source in graph and target in graph:
graph[source].add(target)
# Add reverse edge to capture that we do also have a relationship in the other direction,
# albeit not quite the same one.
graph[target].add(source)
except ValueError:
raise ValueError(f"Invalid relationship format: {relationship}")
# Use BFS to check if all entities are connected
visited: set[str] = set()
start_entity = current_entities[0]
def _bfs(start: str) -> None:
queue = [start]
visited.add(start)
while queue:
current = queue.pop(0)
for neighbor in graph[current]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Start BFS from the first entity
_bfs(start_entity)
logger.debug(f"Number of visited entities: {len(visited)}")
# Check if all current_entities are in visited
return not all(entity in visited for entity in current_entities)
def create_minimal_connected_query_graph(
    entities: list[str], relationships: list[str], max_depth: int = 1
) -> KGExpandedGraphObjects:
    """
    TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
    Return the original entities and relationships.

    Args:
        entities: entity id_names intended for the query graph
        relationships: relationship strings (source__rel__target)
        max_depth: currently unused; reserved for future graph expansion

    Returns:
        KGExpandedGraphObjects echoing the inputs unchanged.
    """
    return KGExpandedGraphObjects(entities=entities, relationships=relationships)
# def rename_entities_in_answer(answer: str) -> str:
# """
# Rename entities in the answer to be more readable by replacing entity references
# with their semantic_id and link. This is case-insensitive and handles spaces between
# entity type and ID. Trailing quotes are removed from entity names.
# """
# # Create a mapping of entity IDs to new names
# entity_mapping = {}
# with get_session_with_current_tenant() as db_session:
# # Get all entity types
# entity_types = get_entity_types(db_session)
# # For each entity type, find all entities in the answer
# for entity_type in entity_types:
# # Find all occurrences of <entity_type>:<entity_name> in the answer (case-insensitive)
# # Pattern now handles spaces after the colon
# pattern = f"{entity_type.id_name}:\\s*([^\\s,;.]+)"
# matches = re.finditer(pattern, answer, re.IGNORECASE)
# for match in matches:
# # Get the full match including any spaces
# full_match = match.group(0)
# # Get just the entity ID part (without spaces) and remove trailing quotes
# entity_name = match.group(1).rstrip("\"'")
# entity_id = f"{entity_type.id_name}:{entity_name}"
# if entity_id.lower() in entity_mapping:
# continue
# # Get the document for this entity
# entity = (
# db_session.query(KGEntity)
# .filter(
# KGEntity.id_name.ilike(
# entity_id
# ) # Case-insensitive comparison
# )
# .first()
# )
# if entity and entity.document_id:
# # Get the document's semantic_id and link
# document = (
# db_session.query(Document)
# .filter(Document.id == entity.document_id)
# .first()
# )
# if document:
# # Create the replacement text with semantic_id and link
# replacement = f"{document.semantic_id}"
# if document.link:
# replacement = f"[{replacement}]({document.link})"
# entity_mapping[entity_id.lower()] = replacement
# # Also map the full match (with spaces) to the same replacement
# entity_mapping[full_match.lower()] = replacement
# # Replace all entity references in the answer (case-insensitive)
# for entity_id, replacement in entity_mapping.items():
# # Use regex for case-insensitive replacement
# answer = re.sub(re.escape(entity_id), replacement, answer, flags=re.IGNORECASE)
# return answer
def stream_write_step_description(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Stream the predefined description of the given step to the frontend."""
    piece = SubQuestionPiece(
        sub_question=STEP_DESCRIPTIONS[step_nr].description,
        level=level,
        level_question_num=step_nr,
    )
    write_custom_event("decomp_qs", piece, writer)
    # Give the frontend a brief moment to catch up
    sleep(0.2)
def stream_write_step_activities(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Stream every predefined activity of a step as a sub-query event.

    Query ids are 1-based, matching the activity's position in the list.
    """
    for query_id, activity_text in enumerate(
        STEP_DESCRIPTIONS[step_nr].activities, start=1
    ):
        piece = SubQueryPiece(
            sub_query=activity_text,
            level=level,
            level_question_num=step_nr,
            query_id=query_id,
        )
        write_custom_event("subqueries", piece, writer)
def stream_write_step_activity_explicit(
    writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
) -> None:
    """Stream a single, explicitly provided activity for a step.

    Unlike stream_write_step_activities, the caller supplies the activity
    text and the query_id directly; exactly one event is emitted.
    """
    # Bug fix: the previous implementation looped over the step's predefined
    # activities, shadowing the `activity` parameter and emitting every
    # predefined activity under the same caller-supplied query_id. The
    # explicit arguments were effectively ignored.
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query=activity,
            level=level,
            level_question_num=step_nr,
            query_id=query_id,
        ),
        writer,
    )
def stream_write_step_answer_explicit(
    writer: StreamWriter, step_nr: int, answer: str, level: int = 0
) -> None:
    """Stream a caller-supplied answer piece for the given step."""
    answer_piece = AgentAnswerPiece(
        answer_piece=answer,
        level=level,
        level_question_num=step_nr,
        answer_type="agent_sub_answer",
    )
    write_custom_event("sub_answers", answer_piece, writer)
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
    """Announce all step descriptions up-front, then mark each per-step
    sub-question stream — and finally the sub-question stream as a whole —
    as finished."""
    # First pass: one sub-question event per step description.
    for step_nr, step_detail in STEP_DESCRIPTIONS.items():
        write_custom_event(
            "decomp_qs",
            SubQuestionPiece(
                sub_question=step_detail.description,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )
    # Second pass: close each step's sub-question stream.
    for step_nr in STEP_DESCRIPTIONS:
        write_custom_event(
            "stream_finished",
            StreamStopInfo(
                stop_reason=StreamStopReason.FINISHED,
                stream_type=StreamType.SUB_QUESTIONS,
                level=level,
                level_question_num=step_nr,
            ),
            writer,
        )
    # Close the sub-question stream as a whole (always at level 0 here).
    overall_stop = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type=StreamType.SUB_QUESTIONS,
        level=0,
    )
    write_custom_event("stream_finished", overall_stop, writer)
def stream_close_step_answer(
    writer: StreamWriter, step_nr: int, level: int = 0
) -> None:
    """Signal that the sub-answer stream for the given step is complete."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.SUB_ANSWER,
            level=level,
            level_question_num=step_nr,
        ),
        writer,
    )
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
    """Signal that the sub-question stream at this level is complete."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.SUB_QUESTIONS,
            level=level,
        ),
        writer,
    )
def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
    """Signal that the main answer stream is complete."""
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.MAIN_ANSWER,
            level=level,
            level_question_num=0,
        ),
        writer,
    )
def stream_write_main_answer_token(
    writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
) -> None:
    """Stream one token of the main answer to the frontend."""
    # No need to add a space here; the tokenizer already handles spacing.
    token_piece = AgentAnswerPiece(
        answer_piece=token,
        level=level,
        level_question_num=level_question_num,
        answer_type="agent_level_answer",
    )
    write_custom_event("initial_agent_answer", token_piece, writer)
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
    """
    Get document information for an entity, including its semantic name and
    document details.

    Falls back to a KGEntityDocInfo with no document fields (and the raw
    id_name used as both semantic names) when the id_name is malformed or no
    backing document exists.
    """
    # Shared fallback; previously this literal was duplicated in two branches.
    fallback = KGEntityDocInfo(
        doc_id=None,
        doc_semantic_id=None,
        doc_link=None,
        semantic_entity_name=entity_id_name,
        semantic_linked_entity_name=entity_id_name,
    )
    # A well-formed entity id_name looks like "<ENTITY_TYPE>:<entity_name>".
    if ":" not in entity_id_name:
        return fallback
    # Only the type part is needed downstream; the old code also unpacked the
    # entity name into an unused local.
    entity_type = entity_id_name.split(":", 1)[0].strip()
    with get_session_with_current_tenant() as db_session:
        entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
        if entity_document_id:
            return get_kg_doc_info_for_entity_name(
                db_session, entity_document_id, entity_type
            )
        return fallback
def rename_entities_in_answer(answer: str) -> str:
    """
    Process entity references in the answer string by:
    1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
    2. Looking up these references in the entity table (active types only)
    3. Replacing valid references with their semantic linked entity names
    Args:
        answer: The input string containing potential entity references
    Returns:
        str: The processed string with entity references replaced
    """
    # Extract all entity references using regex
    # Pattern matches both <str>:<str> and <str>: <str> formats
    pattern = r"([^:\s]+):\s*([^\s,;.]+)"
    matches = re.finditer(pattern, answer)
    # get active entity types
    with get_session_with_current_tenant() as db_session:
        active_entity_types = [
            x.id_name for x in get_entity_types(db_session, active=True)
        ]
    # Collect extracted references (stray leading/trailing colons stripped)
    entity_refs = [match.group(0).strip(":") for match in matches]
    # Create dictionary for processed references
    processed_refs: dict[str, str] = {}
    for entity_ref in entity_refs:
        entity_ref_split = entity_ref.split(":")
        if len(entity_ref_split) != 2:
            # Bug fix: the old message reported the *part* count as a colon
            # count ("number of colons is not 2 but <parts>").
            logger.warning(
                "Invalid entity reference - expected 1 colon (2 parts) but got "
                f"{len(entity_ref_split)} parts"
            )
            continue
        # Reuse the split result instead of splitting a second time.
        entity_type, entity_name = entity_ref_split
        entity_type = entity_type.upper().strip()
        if entity_type not in active_entity_types:
            continue
        entity_name = entity_name.capitalize().strip()
        potential_entity_id_name = f"{entity_type}:{entity_name}"
        replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
        # Only replace references that resolve to an actual document.
        if replacement_candidate.doc_id:
            processed_refs[entity_ref] = (
                replacement_candidate.semantic_linked_entity_name
            )
    # Replace all references in the answer
    for ref, replacement in processed_refs.items():
        answer = answer.replace(ref, replacement)
    return answer
def build_document_context(
    document: InferenceSection | LlmDoc, document_number: int
) -> str:
    """
    Build a context string for a document: a numbered header with the
    semantic identifier, an optional metadata section, and the content.
    """
    if isinstance(document, InferenceSection):
        info_source: InferenceChunk | LlmDoc = document.center_chunk
        info_content = document.combined_content
    elif isinstance(document, LlmDoc):
        info_source = document
        info_content = document.content
    else:
        # Previously this fell through with info_source = None and crashed
        # below with an AttributeError; fail fast with a clear message.
        raise ValueError(f"Unsupported document type: {type(document)}")
    metadata_list = [f" - {key}: {value}" for key, value in info_source.metadata.items()]
    if metadata_list:
        metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
    else:
        metadata_str = ""
    # Construct document header with number and semantic identifier
    doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
    # Combine all parts with proper spacing
    return f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
def get_near_empty_step_results(
    step_number: int,
    step_answer: str,
    verified_reranked_documents: list[InferenceSection] | None = None,
) -> SubQuestionAnswerResults:
    """
    Build a mostly-empty SubQuestionAnswerResults shell for a KG step.

    Args:
        step_number: KG step number; used for the question text and id.
        step_answer: The (already streamed) answer text for the step.
        verified_reranked_documents: Optional documents to attach; defaults
            to an empty list. (The previous `=[]` default was a shared
            mutable default argument — fixed with a None sentinel.)
    """
    return SubQuestionAnswerResults(
        question=STEP_DESCRIPTIONS[step_number].description,
        question_id="0_" + str(step_number),
        answer=step_answer,
        verified_high_quality=True,
        sub_query_retrieval_results=[],
        verified_reranked_documents=(
            [] if verified_reranked_documents is None else verified_reranked_documents
        ),
        context_documents=[],
        cited_documents=[],
        sub_question_retrieval_stats=AgentChunkRetrievalStats(
            verified_count=None,
            verified_avg_scores=None,
            rejected_count=None,
            rejected_avg_scores=None,
            verified_doc_chunk_ids=[],
            dismissed_doc_chunk_ids=[],
        ),
    )

View File

@@ -0,0 +1,49 @@
from pydantic import BaseModel
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import YesNoEnum
class KGQuestionEntityExtractionResult(BaseModel):
    """Entities, free-text terms, and optional time filter extracted from a query."""
    # Entity id_names, possibly carrying "--"-separated attribute suffixes
    entities: list[str]
    # Non-entity search terms found in the question
    terms: list[str]
    # Time-based filter expression, or None/"" when absent
    time_filter: str | None
class KGAnswerApproach(BaseModel):
    """The LLM-chosen approach for answering a KG question."""
    # Whether to answer via SQL over the KG or via (filtered) search
    search_type: KGSearchType
    # How deeply to pursue the answer (e.g. simple vs deep)
    search_strategy: KGAnswerStrategy
    # Desired output format of the final answer
    format: KGAnswerFormat
    # Optional reformulation of the question, if the LLM broke it down
    broken_down_question: str | None = None
    # Whether to split the question into independent sub-problems
    divide_and_conquer: YesNoEnum | None = None
class KGQuestionRelationshipExtractionResult(BaseModel):
    """Relationships extracted from a query by the LLM."""
    # Relationship strings (formatted as "<entity1>__<rel>__<entity2>")
    relationships: list[str]
class KGQuestionExtractionResult(BaseModel):
    """Combined entity/relationship/term extraction result for a query."""
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    # Time-based filter expression, or None when absent
    time_filter: str | None
class KGExpandedGraphObjects(BaseModel):
    """Entities and relationships after query-graph expansion."""
    entities: list[str]
    relationships: list[str]
class KGSteps(BaseModel):
    """UI description of one KG flow step and its activities."""
    # Human-readable summary of the step
    description: str
    # Sub-activity texts streamed as sub-queries for the step
    activities: list[str]
class KGEntityDocInfo(BaseModel):
    """Document info backing a KG entity; doc_* fields are None when no document exists."""
    doc_id: str | None
    doc_semantic_id: str | None
    doc_link: str | None
    # Display name for the entity
    semantic_entity_name: str
    # Display name wrapped as a markdown link when a doc link is available
    semantic_linked_entity_name: str

View File

@@ -0,0 +1,261 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
from onyx.agents.agent_search.kb_search.models import (
KGQuestionRelationshipExtractionResult,
)
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import get_allowed_relationship_type_pairs
from onyx.kg.extractions.extraction_processing import get_entity_types_str
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def extract_ert(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ERTExtractionUpdate:
    """
    LangGraph node (KG flow, step 1): extract entities, relationships, and
    terms (ERT) from the user question via two LLM calls, streaming the
    overall step structure and this step's progress to the UI.
    """
    # recheck KG enablement at outset KG graph
    if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
        logger.error("KG approach is not enabled, the KG agent flow cannot run.")
        raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")
    _KG_STEP_NR = 1
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        user_email = graph_config.tooling.search_tool.user.email
        # Use the mailbox part of the email address as the user name.
        user_name = user_email.split("@")[0] or "unknown"
    # first four lines duplicates from generate_initial_answer
    question = graph_config.inputs.search_request.query
    today_date = datetime.now().strftime("%A, %Y-%m-%d")
    all_entity_types = get_entity_types_str(active=True)
    all_relationship_types = get_relationship_types_str(active=True)
    # Stream structure of substeps out to the UI
    stream_write_step_structure(writer)
    # Now specify core activities in the step (step 1)
    stream_write_step_activities(writer, _KG_STEP_NR)
    ### get the entities, terms, and filters
    query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
        entity_types=all_entity_types,
        relationship_types=all_relationship_types,
    )
    # Fill the placeholder markers and un-escape braces that were doubled so
    # that .format() above would leave them untouched.
    query_extraction_prompt = (
        query_extraction_pre_prompt.replace("---content---", question)
        .replace("---today_date---", today_date)
        .replace("---user_name---", f"EMPLOYEE:{user_name}")
        .replace("{{", "{")
        .replace("}}", "}")
    )
    msg = [
        HumanMessage(
            content=query_extraction_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )
        # Strip markdown fencing/newlines, then isolate the outermost JSON
        # object between the first "{" and the last "}".
        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        try:
            entity_extraction_result = (
                KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # Fall back to an empty extraction so the flow can continue.
            entity_extraction_result = KGQuestionEntityExtractionResult(
                entities=[],
                terms=[],
                time_filter="",
            )
    except Exception as e:
        # LLM invocation failure/timeout: continue with an empty extraction.
        logger.error(f"Error in extract_ert: {e}")
        entity_extraction_result = KGQuestionEntityExtractionResult(
            entities=[],
            terms=[],
            time_filter="",
        )
    # remove the attribute filters from the entities to for the purpose of the relationship
    entities_no_attributes = [
        entity.split("--")[0] for entity in entity_extraction_result.entities
    ]
    entities_string_for_relationships = f"Entities: {entities_no_attributes}\n"
    # NOTE(review): this double-wraps the "Entities:" prefix, producing
    # "Entities: Entities: [...]" in the prompt — confirm this is intended.
    ert_entities_string = f"Entities: {entities_string_for_relationships}\n"
    ### get the relationships
    # find the relationship types that match the extracted entity types
    with get_session_with_current_tenant() as db_session:
        allowed_relationship_pairs = get_allowed_relationship_type_pairs(
            db_session, entity_extraction_result.entities
        )
    query_relationship_extraction_prompt = (
        QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
        .replace("---today_date---", today_date)
        .replace(
            "---relationship_type_options---",
            " - " + "\n - ".join(allowed_relationship_pairs),
        )
        .replace("---identified_entities---", ert_entities_string)
        .replace("---entity_types---", all_entity_types)
        .replace("{{", "{")
        .replace("}}", "}")
    )
    msg = [
        HumanMessage(
            content=query_relationship_extraction_prompt,
        )
    ]
    # NOTE(review): despite the name, this is the primary LLM, not a fast one.
    fast_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            fast_llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=300,
        )
        cleaned_response = (
            str(llm_response.content)
            .replace("{{", "{")
            .replace("}}", "}")
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        # Normalize whitespace and re-quote doubled braces before validation.
        cleaned_response = cleaned_response.replace(" ", "")
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')
        try:
            relationship_extraction_result = (
                KGQuestionRelationshipExtractionResult.model_validate_json(
                    cleaned_response
                )
            )
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            relationship_extraction_result = KGQuestionRelationshipExtractionResult(
                relationships=[],
            )
    except Exception as e:
        # LLM invocation failure/timeout: continue with no relationships.
        logger.error(f"Error in extract_ert: {e}")
        relationship_extraction_result = KGQuestionRelationshipExtractionResult(
            relationships=[],
        )
    ## STEP 1
    # Stream answer pieces out to the UI for Step 1
    extracted_entity_string = " \n ".join(
        [x.split("--")[0] for x in entity_extraction_result.entities]
    )
    extracted_relationship_string = " \n ".join(
        relationship_extraction_result.relationships
    )
    step_answer = f"""Entities and relationships have been extracted from query - \n \
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
    stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
    # Finish Step 1
    stream_close_step_answer(writer, _KG_STEP_NR)
    return ERTExtractionUpdate(
        entities_types_str=all_entity_types,
        relationship_types_str=all_relationship_types,
        extracted_entities_w_attributes=entity_extraction_result.entities,
        extracted_entities_no_attributes=entities_no_attributes,
        extracted_relationships=relationship_extraction_result.relationships,
        extracted_terms=entity_extraction_result.terms,
        time_filter=entity_extraction_result.time_filter,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="extract entities terms",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -0,0 +1,274 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import (
create_minimal_connected_query_graph,
)
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import YesNoEnum
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_document_id_for_entity
from onyx.kg.clustering.normalizations import normalize_entities
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
from onyx.kg.clustering.normalizations import normalize_relationships
from onyx.kg.clustering.normalizations import normalize_terms
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _get_fully_connected_entities(
entities: list[str], relationships: list[str]
) -> list[str]:
"""
Analyze the connectedness of the entities and relationships.
"""
# Build a dictionary to track connections for each entity
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
# Parse relationships to build connection graph
for relationship in relationships:
# Split relationship into parts. Test for proper formatting just in case.
# Should never be an error though at this point.
parts = relationship.split("__")
if len(parts) != 3:
raise ValueError(f"Invalid relationship: {relationship}")
entity1 = parts[0]
entity2 = parts[2]
# Add bidirectional connections
if entity1 in entity_connections:
entity_connections[entity1].add(entity2)
if entity2 in entity_connections:
entity_connections[entity2].add(entity1)
# Find entities connected to all others
fully_connected_entities = []
all_entities = set(entities)
for entity, connections in entity_connections.items():
# Check if this entity is connected to all other entities
if connections == all_entities - {entity}:
fully_connected_entities.append(entity)
return fully_connected_entities
def _check_for_single_doc(
normalized_entities: list[str],
raw_entities: list[str],
normalized_relationships: list[str],
raw_relationships: list[str],
normalized_time_filter: str | None,
) -> str | None:
"""
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
None is returned if the query is not for a single document.
"""
if (
len(normalized_entities) == 1
and len(raw_entities) == 1
and len(normalized_relationships) == 0
and len(raw_relationships) == 0
and normalized_time_filter is None
):
with get_session_with_current_tenant() as db_session:
single_doc_id = get_document_id_for_entity(
db_session, normalized_entities[0]
)
else:
single_doc_id = None
return single_doc_id
def analyze(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> AnalysisUpdate:
    """
    LangGraph node (KG flow, step 2): normalize the extracted entities /
    relationships / terms, detect single-document queries, expand the query
    graph into a connected one, and let the LLM pick search type, strategy,
    and output format.
    """
    _KG_STEP_NR = 2
    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities = (
        state.extracted_entities_no_attributes
    )  # attribute knowledge is not required for this step
    relationships = state.extracted_relationships
    terms = state.extracted_terms
    time_filter = state.time_filter
    ## STEP 2 - stream out goals
    stream_write_step_activities(writer, _KG_STEP_NR)
    # Continue with node
    normalized_entities = normalize_entities(entities)
    query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
        state.extracted_entities_w_attributes,
        normalized_entities.entity_normalization_map,
    )
    normalized_relationships = normalize_relationships(
        relationships, normalized_entities.entity_normalization_map
    )
    normalized_terms = normalize_terms(terms)
    normalized_time_filter = time_filter
    # If single-doc inquiry, send to single-doc processing directly
    single_doc_id = _check_for_single_doc(
        normalized_entities=normalized_entities.entities,
        raw_entities=entities,
        normalized_relationships=normalized_relationships.relationships,
        raw_relationships=relationships,
        normalized_time_filter=normalized_time_filter,
    )
    # Expand the entities and relationships to make sure that entities are connected
    graph_expansion = create_minimal_connected_query_graph(
        normalized_entities.entities,
        normalized_relationships.relationships,
        max_depth=2,
    )
    query_graph_entities = graph_expansion.entities
    query_graph_relationships = graph_expansion.relationships
    # Evaluate whether a search needs to be done after identifying all entities and relationships
    strategy_generation_prompt = (
        STRATEGY_GENERATION_PROMPT.replace(
            "---entities---", "\n".join(query_graph_entities)
        )
        .replace("---relationships---", "\n".join(query_graph_relationships))
        .replace("---possible_entities---", state.entities_types_str)
        .replace("---possible_relationships---", state.relationship_types_str)
        .replace("---question---", question)
    )
    msg = [
        HumanMessage(
            content=strategy_generation_prompt,
        )
    ]
    primary_llm = graph_config.tooling.primary_llm
    # Grader
    try:
        llm_response = run_with_timeout(
            20,
            # fast_llm.invoke,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=5,
            max_tokens=100,
        )
        # Strip markdown fencing, then isolate the outermost JSON object.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        try:
            approach_extraction_result = KGAnswerApproach.model_validate_json(
                cleaned_response
            )
            search_type = approach_extraction_result.search_type
            search_strategy = approach_extraction_result.search_strategy
            output_format = approach_extraction_result.format
            broken_down_question = approach_extraction_result.broken_down_question
            divide_and_conquer = approach_extraction_result.divide_and_conquer
        except ValueError:
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            # Fall back to a deep text search if the approach cannot be parsed.
            search_type = KGSearchType.SEARCH
            search_strategy = KGAnswerStrategy.DEEP
            output_format = KGAnswerFormat.TEXT
            broken_down_question = None
            divide_and_conquer = YesNoEnum.NO
        if search_strategy is None or output_format is None:
            raise ValueError(f"Invalid strategy: {cleaned_response}")
    except Exception as e:
        # Unlike step 1, a failure here aborts the flow — there is no usable
        # default answering approach without a strategy.
        logger.error(f"Error in strategy generation: {e}")
        raise e
    # Stream out relevant results
    if single_doc_id:
        search_strategy = (
            KGAnswerStrategy.DEEP
        )  # if a single doc is identified, we will want to look at the details.
    step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
Format: {output_format.value}, Broken down question: {broken_down_question}"
    stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
    stream_close_step_answer(writer, _KG_STEP_NR)
    # End node
    return AnalysisUpdate(
        normalized_core_entities=normalized_entities.entities,
        normalized_core_relationships=normalized_relationships.relationships,
        query_graph_entities_no_attributes=query_graph_entities,
        query_graph_entities_w_attributes=query_graph_entities_w_attributes,
        query_graph_relationships=query_graph_relationships,
        normalized_terms=normalized_terms.terms,
        normalized_time_filter=normalized_time_filter,
        strategy=search_strategy,
        broken_down_question=broken_down_question,
        output_format=output_format,
        divide_and_conquer=divide_and_conquer,
        single_doc_id=single_doc_id,
        search_type=search_type,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="analyze",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -0,0 +1,394 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from sqlalchemy import text
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import KGSearchType
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.kg_temp_view import create_views
from onyx.db.kg_temp_view import drop_views
from onyx.llm.interfaces import LLM
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def _sql_is_aggregate_query(sql_statement: str) -> bool:
return any(
agg_func in sql_statement.upper()
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
)
def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
    """
    Use the LLM to rewrite the SQL statement without aggregate functions.

    Raises whatever exception occurs if the LLM call or response parsing
    fails (after logging it).
    """
    sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
        "---sql_statement---", sql_statement
    )
    msg = [
        HumanMessage(
            content=sql_aggregation_removal_prompt,
        )
    ]
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=800,
        )
        # Strip markdown code fencing from the raw LLM output.
        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        # Extract the statement between <sql> ... </sql> tags.
        sql_statement = cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
        # NOTE(review): this removes EVERY occurrence of the substring "sql",
        # not just a leading language tag — it could corrupt identifiers that
        # contain "sql". Confirm whether removeprefix("sql") was intended.
        sql_statement = sql_statement.replace("sql", "").strip()
    except Exception as e:
        logger.error(f"Error in strategy generation: {e}")
        raise e
    return sql_statement
def _get_source_documents(sql_statement: str, llm: LLM) -> str | None:
    """
    Use the LLM to derive a SQL statement that retrieves the source documents
    for the given SQL statement.

    Returns None (after logging) if the LLM call fails or the response cannot
    be parsed. (The previous docstring was copy-pasted from
    _remove_aggregation and described the wrong behavior.)
    """
    source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
        "---original_sql_statement---", sql_statement
    )
    msg = [
        HumanMessage(
            content=source_detection_prompt,
        )
    ]
    # Kept outside the try so the except block can tell whether the failure
    # happened before or after a model response was received.
    cleaned_response: str | None = None
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=25,
            max_tokens=1200,
        )
        # Strip markdown code fencing from the raw LLM output.
        cleaned_response = (
            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
        )
        # Extract the statement between <sql> ... </sql> tags.
        sql_statement = cleaned_response.split("<sql>")[1].strip()
        sql_statement = sql_statement.split("</sql>")[0].strip()
    except Exception as e:
        if cleaned_response is not None:
            logger.error(
                f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
            )
        else:
            logger.error(f"Could not generate source documents SQL: {e}")
        return None
    return sql_statement
def generate_simple_sql(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SQLSimpleGenerationUpdate:
    """
    LangGraph node to start the agentic search process.

    Depending on the search type, either short-circuits (single doc already
    identified, or filtered search planned) or generates a KG SQL statement
    with the LLM, corrects it, executes it via a read-only DB session against
    per-user access-filtered views, and derives/executes a companion SQL
    statement that fetches the underlying source documents.

    Args:
        state: current KG search graph state (query entities/relationships etc.).
        config: LangGraph runnable config; carries the GraphConfig in metadata.
        writer: stream writer used to emit step progress to the UI.

    Returns:
        SQLSimpleGenerationUpdate with the SQL, its results, the
        source-document SQL/results, and step logging info.
    """
    _KG_STEP_NR = 3

    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities_types_str = state.entities_types_str
    relationship_types_str = state.relationship_types_str
    single_doc_id = state.single_doc_id
    # NOTE(review): no-op expression statement — likely leftover from a refactor.
    state.search_type

    ## STEP 3 - articulate goals
    stream_write_step_activities(writer, _KG_STEP_NR)

    if graph_config.tooling.search_tool is None:
        raise ValueError("Search tool is not set")
    elif graph_config.tooling.search_tool.user is None:
        raise ValueError("User is not set")
    else:
        user_email = graph_config.tooling.search_tool.user.email
        # NOTE(review): assumes the local part of the email is the user name —
        # confirm this matches how EMPLOYEE entities are keyed.
        user_name = user_email.split("@")[0]

    if state.search_type == KGSearchType.SQL and single_doc_id:
        # If single doc id already identified, we do not need to go through the KG
        # query cycle, saving a lot of time.
        main_sql_statement: str | None = None
        query_results: list[dict[str, Any]] | None = None
        source_documents_sql: str | None = None
        source_document_results: list[str] | None = [single_doc_id]
        reasoning: str | None = (
            f"A KG query was not required as the source document was already identified: {single_doc_id}"
        )
        step_answer = f"Source document already identified: {single_doc_id}"

    elif state.search_type == KGSearchType.SEARCH:
        # If we do a filtered search, then we do not need to go through the SQL
        # generation process.
        main_sql_statement = None
        query_results = None
        source_documents_sql = None
        source_document_results = None
        reasoning = "A KG query was not required as we will use a filtered search."
        step_answer = "Filtered search will be used."

    else:
        # If no single doc id already identified, we need to go through the KG
        # query cycle, including generating the SQL for the answer and the sources

        # Build prompt
        simple_sql_prompt = (
            SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
            .replace("---relationship_types---", relationship_types_str)
            .replace("---question---", question)
            .replace(
                "---query_entities_with_attributes---",
                "\n".join(state.query_graph_entities_w_attributes),
            )
            .replace(
                "---query_relationships---", "\n".join(state.query_graph_relationships)
            )
            .replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
            .replace("---user_name---", f"EMPLOYEE:{user_name}")
        )

        logger.debug(f"simple_sql_prompt: {simple_sql_prompt}")

        # Create temporary view
        # View names are derived from the user email so concurrent users do not
        # collide; '@' and '.' are not valid in SQL identifiers.
        allowed_docs_view_name = f"allowed_docs_{user_email}".replace("@", "_").replace(
            ".", "_"
        )
        kg_relationships_view_name = (
            f"kg_relationships_with_access_{user_email}".replace("@", "_").replace(
                ".", "_"
            )
        )

        with get_session_with_current_tenant() as db_session:
            create_views(
                db_session,
                user_email=user_email,
                allowed_docs_view_name=allowed_docs_view_name,
                kg_relationships_view_name=kg_relationships_view_name,
            )
        # NOTE(review): if anything below raises, the views created above are
        # never dropped — consider wrapping the rest of this branch in
        # try/finally with drop_views in the finally.

        # prepare SQL query generation
        msg = [
            HumanMessage(
                content=simple_sql_prompt,
            )
        ]

        primary_llm = graph_config.tooling.primary_llm

        # Grader
        try:
            llm_response = run_with_timeout(
                15,
                primary_llm.invoke,
                prompt=msg,
                timeout_override=25,
                max_tokens=1500,
            )

            # Strip markdown fences, then extract the <sql>...</sql> payload.
            cleaned_response = (
                str(llm_response.content).replace("```json\n", "").replace("\n```", "")
            )
            sql_statement = (
                cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
            )
            # Keep only the first statement and normalize the terminator.
            sql_statement = sql_statement.split(";")[0].strip() + ";"
            # NOTE(review): this removes EVERY occurrence of "sql" (presumably
            # intended to strip a ```sql fence tag) — would corrupt identifiers
            # containing "sql"; a removeprefix would be safer.
            sql_statement = sql_statement.replace("sql", "").strip()
            # Point the query at the per-user access-filtered relationships view.
            sql_statement = sql_statement.replace(
                "kg_relationship", kg_relationships_view_name
            )

            reasoning = (
                cleaned_response.split("<reasoning>")[1]
                .strip()
                .split("</reasoning>")[0]
            )

        except Exception as e:
            logger.error(f"Error in strategy generation: {e}")
            raise e

        # Correction if needed:
        correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
            "---draft_sql---", sql_statement
        )

        msg = [
            HumanMessage(
                content=correction_prompt,
            )
        ]

        try:
            llm_response = run_with_timeout(
                15,
                primary_llm.invoke,
                prompt=msg,
                timeout_override=25,
                max_tokens=1500,
            )

            cleaned_response = (
                str(llm_response.content).replace("```json\n", "").replace("\n```", "")
            )
            sql_statement = (
                cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
            )
            sql_statement = sql_statement.split(";")[0].strip() + ";"
            sql_statement = sql_statement.replace("sql", "").strip()
        except Exception as e:
            logger.error(
                f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
            )
            raise e

        # Get SQL for source documents
        source_documents_sql = _get_source_documents(sql_statement, llm=primary_llm)
        logger.debug(f"source_documents_sql: {source_documents_sql}")

        scalar_result = None
        query_results = None
        # Execute the generated SQL via the restricted read-only DB user.
        with get_db_readonly_user_session_with_current_tenant() as db_session:
            try:
                result = db_session.execute(text(sql_statement))
                # Handle scalar results (like COUNT)
                if sql_statement.upper().startswith("SELECT COUNT"):
                    scalar_result = result.scalar()
                    query_results = (
                        [{"count": int(scalar_result)}]
                        if scalar_result is not None
                        else []
                    )
                else:
                    # Handle regular row results
                    rows = result.fetchall()
                    query_results = [dict(row._mapping) for row in rows]
            except Exception as e:
                logger.error(f"Error executing SQL query: {e}")
                raise e

        source_document_results = None
        # Skip re-execution when the source SQL is identical to the main SQL.
        if source_documents_sql is not None and source_documents_sql != sql_statement:
            with get_db_readonly_user_session_with_current_tenant() as db_session:
                try:
                    result = db_session.execute(text(source_documents_sql))
                    rows = result.fetchall()
                    query_source_document_results = [dict(row._mapping) for row in rows]
                    # NOTE(review): assumes each row exposes a "source_document"
                    # column — confirm against SOURCE_DETECTION_PROMPT.
                    source_document_results = [
                        source_document_result["source_document"]
                        for source_document_result in query_source_document_results
                    ]
                except Exception as e:
                    # No stopping here, the individualized SQL query is not mandatory
                    logger.error(f"Error executing Individualized SQL query: {e}")
                    # individualized_query_results = None
        else:
            source_document_results = None

        with get_session_with_current_tenant() as db_session:
            drop_views(
                db_session,
                allowed_docs_view_name=allowed_docs_view_name,
                kg_relationships_view_name=kg_relationships_view_name,
            )

        logger.info(f"query_results: {query_results}")
        logger.debug(f"sql_statement: {sql_statement}")

        # Stream out reasoning and SQL query
        step_answer = f"Reasoning: {reasoning} \n \n SQL Query: {sql_statement}"
        main_sql_statement = sql_statement

    if reasoning:
        stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)

    stream_close_step_answer(writer, _KG_STEP_NR)

    return SQLSimpleGenerationUpdate(
        sql_query=main_sql_statement,
        sql_query_results=query_results,
        individualized_sql_query=None,
        individualized_query_results=None,
        source_documents_sql=source_documents_sql,
        source_document_results=source_document_results or [],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="generate simple sql",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=[],
            )
        ],
    )

View File

@@ -0,0 +1,178 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def construct_deep_search_filters(
    state: MainState, config: RunnableConfig, writer: StreamWriter
) -> DeepSearchFilterUpdate:
    """
    LangGraph node to start the agentic search process.

    Uses the LLM to construct search filters (global/local entity filters,
    relationship filters, source-document filters, and a divide-and-conquer
    structure) from the question, the previously generated SQL, and its
    results.

    Args:
        state: current KG search graph state.
        config: LangGraph runnable config; carries the GraphConfig in metadata.
        writer: stream writer (not used by this node).

    Returns:
        DeepSearchFilterUpdate carrying the constructed filters and log info.
    """
    node_start_time = datetime.now()

    graph_config = cast(GraphConfig, config["metadata"]["config"])
    question = graph_config.inputs.search_request.query
    entities_types_str = state.entities_types_str
    entities = state.query_graph_entities_no_attributes
    relationships = state.query_graph_relationships
    simple_sql_query = state.sql_query
    simple_sql_results = state.sql_query_results
    source_document_results = state.source_document_results

    # Render prior SQL results (if any) for inclusion in the prompt.
    if simple_sql_results:
        simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
    else:
        simple_sql_results_str = "(no SQL results generated)"

    if source_document_results:
        source_document_results_str = "\n".join(
            [str(x) for x in source_document_results]
        )
    else:
        source_document_results_str = "(no source document results generated)"

    # Fill the filter-construction prompt template.
    search_filter_construction_prompt = (
        SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
            "---entity_type_descriptions---",
            entities_types_str,
        )
        .replace(
            "---entity_filters---",
            "\n".join(entities),
        )
        .replace(
            "---relationship_filters---",
            "\n".join(relationships),
        )
        .replace(
            "---sql_query---",
            simple_sql_query or "(no SQL generated)",
        )
        .replace(
            "---sql_results---",
            simple_sql_results_str or "(no SQL results generated)",
        )
        .replace(
            "---source_document_results---",
            source_document_results_str or "(no source document results generated)",
        )
        .replace(
            "---question---",
            question,
        )
    )

    msg = [
        HumanMessage(
            content=search_filter_construction_prompt,
        )
    ]

    llm = graph_config.tooling.primary_llm

    # Grader
    try:
        llm_response = run_with_timeout(
            15,
            llm.invoke,
            prompt=msg,
            timeout_override=15,
            max_tokens=1400,
        )

        # Strip markdown fences and newlines so the payload is one JSON line.
        cleaned_response = (
            str(llm_response.content)
            .replace("```json\n", "")
            .replace("\n```", "")
            .replace("\n", "")
        )
        # Keep only the outermost {...} span of the response.
        first_bracket = cleaned_response.find("{")
        last_bracket = cleaned_response.rfind("}")
        if last_bracket == -1 or first_bracket == -1:
            raise ValueError("No valid JSON found in LLM response - no brackets found")
        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
        # NOTE(review): presumably repairs doubled braces the model emits for
        # templated JSON output — confirm against the prompt's expected format.
        cleaned_response = cleaned_response.replace("{{", '{"')
        cleaned_response = cleaned_response.replace("}}", '"}')

        try:
            filter_results = KGFilterConstructionResults.model_validate_json(
                cleaned_response
            )
        except ValueError:
            # Fall back to empty filters on malformed JSON rather than failing.
            logger.error(
                "Failed to parse LLM response as JSON in Entity-Term Extraction"
            )
            filter_results = KGFilterConstructionResults(
                global_entity_filters=[],
                global_relationship_filters=[],
                local_entity_filters=[],
                source_document_filters=[],
                structure=[],
            )

    except Exception as e:
        # Same empty-filters fallback for LLM invocation/timeout failures.
        logger.error(f"Error in extract_ert: {e}")
        filter_results = KGFilterConstructionResults(
            global_entity_filters=[],
            global_relationship_filters=[],
            local_entity_filters=[],
            source_document_filters=[],
            structure=[],
        )

    div_con_structure = filter_results.structure

    logger.info(f"div_con_structure: {div_con_structure}")

    with get_session_with_current_tenant() as db_session:
        double_grounded_entity_types = get_entity_types_with_grounded_source_name(
            db_session
        )

    # Flag when the first divide-and-conquer element refers to a grounded
    # source entity type (case-insensitive substring match).
    source_division = False
    if div_con_structure:
        for entity_type in double_grounded_entity_types:
            if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
                source_division = True
                break

    return DeepSearchFilterUpdate(
        vespa_filter_results=filter_results,
        div_con_entities=div_con_structure,
        source_division=source_division,
        global_entity_filters=[
            f"{global_filter}:*"
            for global_filter in filter_results.global_entity_filters
        ],
        global_relationship_filters=filter_results.global_relationship_filters,
        local_entity_filters=filter_results.local_entity_filters,
        source_filters=filter_results.source_document_filters,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="construct deep search filters",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -0,0 +1,180 @@
import copy
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import (
get_doc_information_for_entity,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import LlmDoc
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def process_individual_deep_search(
    state: ResearchObjectInput,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ResearchObjectUpdate:
    """
    LangGraph node that researches one KG object: runs a KG-filtered document
    search for the object-scoped question, then summarizes the retrieved
    documents with the LLM.

    Args:
        state: per-object research input (entity, filters, sub-question).
        config: LangGraph runnable config; carries the GraphConfig in metadata.
        writer: stream writer used to emit the research sub-query to the UI.

    Returns:
        ResearchObjectUpdate with this object's research result and log info.

    Raises:
        ValueError: if the search tool is missing, a retrieved document has an
            unexpected type, or the LLM invocation fails.
    """
    _KG_STEP_NR = 4

    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = state.broken_down_question
    source_division = state.source_division
    source_filters = state.source_entity_filters

    # Renamed from `object` to avoid shadowing the builtin.
    kg_object = state.entity.replace(":", ": ").lower()
    kg_object_id = kg_object.split(":")[1].strip()

    if not search_tool:
        raise ValueError("search_tool is not provided")

    research_nr = state.research_nr

    extended_question = f"{question} in regards to {kg_object}"
    if source_division:
        extended_question = question

    raw_kg_entity_filters = copy.deepcopy(
        list(set((state.vespa_filter_results.global_entity_filters + [state.entity])))
    )
    kg_entity_filters = []
    for raw_kg_entity_filter in raw_kg_entity_filters:
        # Filters without an id part match any entity of that type.
        if ":" not in raw_kg_entity_filter:
            raw_kg_entity_filter += ":*"
        kg_entity_filters.append(raw_kg_entity_filter)

    kg_relationship_filters = copy.deepcopy(
        state.vespa_filter_results.global_relationship_filters
    )

    # if this is a single-object analysis and object is a source,
    # drop the other filters as they already were included
    # in the KG query that led to this analysis
    if source_division and source_filters:
        for source_filter in source_filters:
            if kg_object_id.lower() == source_filter.lower():
                source_filters = [source_filter]
                kg_relationship_filters = []
                kg_entity_filters = []
                break

    logger.info("Research for object: " + kg_object)
    logger.info(f"kg_entity_filters: {kg_entity_filters}")
    logger.info(f"kg_relationship_filters: {kg_relationship_filters}")

    # Step 4 - stream out the research query
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query=f"{get_doc_information_for_entity(kg_object).semantic_entity_name}",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=research_nr + 1,
        ),
        writer,
    )

    retrieved_docs = research(
        question=extended_question,
        kg_entities=kg_entity_filters,
        kg_relationships=kg_relationship_filters,
        kg_sources=source_filters,
        search_tool=search_tool,
    )

    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(retrieved_docs):
        if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
            raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    # Build prompt (a dead `datetime.now().strftime(...)` statement whose
    # result was discarded has been removed here)
    kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
        question=extended_question,
        document_text=document_texts,
    )

    # Run LLM
    primary_llm = graph_config.tooling.primary_llm
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
        object_research_results = str(llm_response.content).replace("```json\n", "")
    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")

    logger.debug("DivCon Step A2 - Object Source Research - completed for an object")

    return ResearchObjectUpdate(
        research_object_results=[
            {
                "object": kg_object.replace(":", ": ").capitalize(),
                "results": object_research_results,
            }
        ],
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="process individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[],
    )

View File

@@ -0,0 +1,176 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.chat.models import SubQueryPiece
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def filtered_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """
    LangGraph node to do a filtered search.

    Runs a KG-filtered document search for the original question, summarizes
    the retrieved documents with the LLM, and streams step progress to the UI.

    Args:
        state: current KG search graph state (must carry vespa_filter_results).
        config: LangGraph runnable config; carries the GraphConfig in metadata.
        writer: stream writer used to emit the sub-query and step answer.

    Returns:
        ConsolidatedResearchUpdate with the filtered-search answer and log info.

    Raises:
        ValueError: if the search tool or filter results are missing, or the
            LLM invocation fails.
    """
    _KG_STEP_NR = 4

    node_start_time = datetime.now()
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    search_tool = graph_config.tooling.search_tool
    question = graph_config.inputs.search_request.query

    if not search_tool:
        raise ValueError("search_tool is not provided")
    if not state.vespa_filter_results:
        raise ValueError("vespa_filter_results is not provided")

    raw_kg_entity_filters = list(
        set((state.vespa_filter_results.global_entity_filters))
    )
    kg_entity_filters = []
    for raw_kg_entity_filter in raw_kg_entity_filters:
        # Filters without an id part match any entity of that type.
        if ":" not in raw_kg_entity_filter:
            raw_kg_entity_filter += ":*"
        kg_entity_filters.append(raw_kg_entity_filter)

    kg_relationship_filters = state.vespa_filter_results.global_relationship_filters

    logger.info("Starting filtered search")
    logger.debug(f"kg_entity_filters: {kg_entity_filters}")
    logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")

    # Step 4 - stream out the research query
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query="Conduct a filtered search",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=1,
        ),
        writer,
    )

    retrieved_docs = cast(
        list[InferenceSection],
        research(
            question=question,
            kg_entities=kg_entity_filters,
            kg_relationships=kg_relationship_filters,
            kg_sources=None,
            search_tool=search_tool,
            inference_sections_only=True,
        ),
    )

    answer_generation_documents = get_answer_generation_documents(
        relevant_docs=retrieved_docs,
        context_documents=retrieved_docs,
        original_question_docs=retrieved_docs,
        max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
    )

    document_texts_list = []
    for doc_num, retrieved_doc in enumerate(
        answer_generation_documents.context_documents
    ):
        chunk_text = build_document_context(retrieved_doc, doc_num + 1)
        document_texts_list.append(chunk_text)

    document_texts = "\n\n".join(document_texts_list)

    # Build prompt (a dead `datetime.now().strftime(...)` statement whose
    # result was discarded has been removed here)
    kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
        question=question,
        document_text=document_texts,
    )

    # Run LLM
    primary_llm = graph_config.tooling.primary_llm
    msg = [
        HumanMessage(
            content=trim_prompt_piece(
                config=primary_llm.config,
                prompt_piece=kg_object_source_research_prompt,
                reserved_str="",
            ),
        )
    ]
    try:
        llm_response = run_with_timeout(
            30,
            primary_llm.invoke,
            prompt=msg,
            timeout_override=30,
            max_tokens=300,
        )
        filtered_search_answer = str(llm_response.content).replace("```json\n", "")
    except Exception as e:
        raise ValueError(f"Error in research_object_source: {e}")

    step_answer = "Filtered search is complete."
    stream_write_step_answer_explicit(
        writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
    )
    stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)

    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=filtered_search_answer,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                # was mislabeled "generate simple sql" (copy/paste from the SQL node)
                node_name="filtered search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR,
                step_answer=step_answer,
                verified_reranked_documents=retrieved_docs,
            )
        ],
    )

View File

@@ -0,0 +1,70 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def consolidate_individual_deep_search(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ConsolidatedResearchUpdate:
    """
    LangGraph node that merges the per-object research results into a single
    consolidated string for downstream answer generation.

    Args:
        state: current KG search graph state carrying research_object_results.
        config: LangGraph runnable config (kept for the node signature; unused —
            the previous no-op reads of graph_config/state attributes were removed).
        writer: stream writer used to emit the step answer.

    Returns:
        ConsolidatedResearchUpdate with the consolidated results string.
    """
    _KG_STEP_NR = 4

    node_start_time = datetime.now()

    research_object_results = state.research_object_results

    consolidated_research_object_results_str = "\n".join(
        [f"{x['object']}: {x['results']}" for x in research_object_results]
    )
    # Map internal entity id-names back to display names in the answer text.
    consolidated_research_object_results_str = rename_entities_in_answer(
        consolidated_research_object_results_str
    )

    step_answer = "All research is complete. Consolidating results..."
    stream_write_step_answer_explicit(
        writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
    )
    stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)

    return ConsolidatedResearchUpdate(
        consolidated_research_object_results_str=consolidated_research_object_results_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                # was mislabeled "generate simple sql" (copy/paste from the SQL node)
                node_name="consolidate individual deep search",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -0,0 +1,135 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_step_answer_explicit,
)
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.db.document import get_base_llm_doc_information
from onyx.db.engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _general_format(result: dict[str, Any]) -> str:
name = result.get("name")
entity_type_id_name: Any = result.get("entity_type_id_name")
result.get("id_name")
assert entity_type_id_name is str
return f"{entity_type_id_name.capitalize()}: {name}"
def _get_formated_source_reference_results(
    source_document_results: list[str] | None,
) -> str | None:
    """
    Build a reference string listing example source documents.

    Returns None when no source document ids were provided, an empty string
    when none of the ids resolve to Onyx documents, and otherwise a formatted
    list of the resolved document descriptions.
    """
    if source_document_results is None:
        return None

    # Resolve the document ids to LLM-ready document descriptions.
    with get_session_with_current_tenant() as session:
        doc_descriptions = get_base_llm_doc_information(
            session, source_document_results
        )

    if not doc_descriptions:
        return ""

    header = f"\n \n Here are {len(doc_descriptions)} examples: \n \n "
    return header + " \n \n ".join(doc_descriptions)
def process_kg_only_answers(
    state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResultsDataUpdate:
    """
    LangGraph node that formats the KG-only query results (no further document
    research needed) and the source-document references for answer generation.

    Args:
        state: current KG search graph state with sql_query_results and
            source_document_results.
        config: LangGraph runnable config (kept for the node signature; unused —
            the previous no-op attribute reads were removed).
        writer: stream writer used to emit step progress.

    Returns:
        ResultsDataUpdate with the formatted query-results and reference strings.

    Raises:
        ValueError: if no SQL query results are available.
    """
    _KG_STEP_NR = 4

    node_start_time = datetime.now()

    query_results = state.sql_query_results
    source_document_results = state.source_document_results

    # we use this stream write explicitly
    write_custom_event(
        "subqueries",
        SubQueryPiece(
            sub_query="Formatted References",
            level=0,
            level_question_num=_KG_STEP_NR,
            query_id=1,
        ),
        writer,
    )

    query_results_list = []
    if query_results:
        for query_result in query_results:
            query_results_list.append(str(query_result).replace(":", ": ").capitalize())
    else:
        raise ValueError("No query results were found")

    query_results_data_str = "\n".join(query_results_list)

    source_reference_result_str = _get_formated_source_reference_results(
        source_document_results
    )

    ## STEP 4 - same components as Step 1
    step_answer = (
        "No further research is needed, the answer is derived from the knowledge graph."
    )
    stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
    stream_close_step_answer(writer, _KG_STEP_NR)

    return ResultsDataUpdate(
        query_results_data_str=query_results_data_str,
        individualized_query_results_data_str="",
        reference_results_str=source_reference_result_str,
        log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
                node_name="kg query results data processing",
                node_start_time=node_start_time,
            )
        ],
        step_results=[
            get_near_empty_step_results(
                step_number=_KG_STEP_NR, step_answer=step_answer
            )
        ],
    )

View File

@@ -0,0 +1,284 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_close_main_answer,
)
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
from onyx.agents.agent_search.kb_search.graph_utils import (
stream_write_main_answer_token,
)
from onyx.agents.agent_search.kb_search.ops import research
from onyx.agents.agent_search.kb_search.states import MainOutput
from onyx.agents.agent_search.kb_search.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.chat import log_agent_sub_question_results
from onyx.db.engine import get_session_with_current_tenant
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import SearchType
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def generate_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> MainOutput:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
if search_tool is None:
raise ValueError("Search tool is not set")
# Close out previous streams of steps
# DECLARE STEPS DONE
stream_write_close_steps(writer)
## MAIN ANSWER
# identify whether documents have already been retrieved
retrieved_docs: list[InferenceSection] = []
for step_result in state.step_results:
retrieved_docs += step_result.verified_reranked_documents
# if still needed, get a search done and send the results to the UI
if not retrieved_docs and state.source_document_results:
assert graph_config.tooling.search_tool is not None
retrieved_docs = cast(
list[InferenceSection],
research(
question=question,
kg_entities=[],
kg_relationships=[],
kg_sources=state.source_document_results[
:KG_RESEARCH_NUM_RETRIEVED_DOCS
],
search_tool=graph_config.tooling.search_tool,
kg_chunk_id_zero_only=True,
inference_sections_only=True,
),
)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=retrieved_docs,
context_documents=retrieved_docs,
original_question_docs=retrieved_docs,
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
)
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
assert graph_config.tooling.search_tool is not None
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=SearchQueryInfo(
predicted_search=SearchType.KEYWORD,
# acl here is empty, because the searach alrady happened and
# we are streaming out the results.
final_filters=IndexFilters(access_control_list=[]),
recency_bias_multiplier=1.0,
),
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
# continue with the answer generation
output_format = (
state.output_format.value
if state.output_format
else "<you be the judge how to best present the data>"
)
# if deep path was taken:
consolidated_research_object_results_str = (
state.consolidated_research_object_results_str
)
reference_results_str = (
state.reference_results_str
) # will not be part of LLM. Manually added to the answer
# if simple path was taken:
introductory_answer = state.query_results_data_str # from simple answer path only
if consolidated_research_object_results_str:
research_results = consolidated_research_object_results_str
else:
research_results = ""
if introductory_answer:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace(
"---introductory_answer---",
rename_entities_in_answer(introductory_answer),
)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace(
"---introductory_answer---",
rename_entities_in_answer(consolidated_research_object_results_str),
)
.replace("---output_format---", str(output_format) if output_format else "")
)
elif research_results and not consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace(
"---research_results---", rename_entities_in_answer(research_results)
)
)
elif consolidated_research_object_results_str:
output_format_prompt = (
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
.replace("---output_format---", str(output_format) if output_format else "")
.replace(
"---research_results---", rename_entities_in_answer(research_results)
)
)
else:
raise ValueError("No research results or introductory answer provided")
msg = [
HumanMessage(
content=output_format_prompt,
)
]
fast_llm = graph_config.tooling.fast_llm
dispatch_timings: list[float] = []
response: list[str] = []
def stream_answer() -> list[str]:
# Get the LLM's tokenizer
llm_tokenizer = get_tokenizer(
model_name=fast_llm.config.model_name,
provider_type=fast_llm.config.model_provider,
)
for message in fast_llm.stream(
prompt=msg,
timeout_override=30,
max_tokens=1000,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
# Tokenize the content using the LLM's tokenizer
tokens = llm_tokenizer.tokenize(content)
for token in tokens:
start_stream_token = datetime.now()
stream_write_main_answer_token(
writer, token, level=0, level_question_num=0
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(token)
return response
try:
response = run_with_timeout(
30,
stream_answer,
)
if reference_results_str:
# Get the LLM's tokenizer
llm_tokenizer = get_tokenizer(
model_name=fast_llm.config.model_name,
provider_type=fast_llm.config.model_provider,
)
# Tokenize and stream the reference results
tokens = llm_tokenizer.tokenize(reference_results_str)
for token in tokens:
# Replace newlines with HTML line breaks
# if token == ' \n':
# token = ' <br>'
stream_write_main_answer_token(writer, token)
except Exception as e:
raise ValueError(f"Could not generate the answer. Error {e}")
stream_write_close_main_answer(writer)
# Persist the sub-answer in the database
with get_session_with_current_tenant() as db_session:
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.step_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
return MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="query completed",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,65 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
    question: str,
    search_tool: SearchTool,
    document_sources: list[DocumentSource] | None = None,
    time_cutoff: datetime | None = None,
    kg_entities: list[str] | None = None,
    kg_relationships: list[str] | None = None,
    kg_terms: list[str] | None = None,
    kg_sources: list[str] | None = None,
    kg_chunk_id_zero_only: bool = False,
    inference_sections_only: bool = False,
) -> list[LlmDoc] | list[InferenceSection]:
    """Run the search tool for `question` with optional KG filters.

    Depending on `inference_sections_only`, returns either the top
    InferenceSections from the search-summary packet or the final context
    LlmDocs, each capped at KG_RESEARCH_NUM_RETRIEVED_DOCS. Returns an
    empty list if neither packet is emitted by the tool.
    """
    sections_callback: list[list[InferenceSection]] = []
    results: list[LlmDoc] | list[InferenceSection] = []

    # Open a dedicated session so the search does not share a transaction
    # with the caller (avoids concurrency issues).
    with get_session_with_current_tenant() as db_session:
        override_kwargs = SearchToolOverrideKwargs(
            force_no_rerank=False,
            alternate_db_session=db_session,
            retrieved_sections_callback=sections_callback.append,
            skip_query_analysis=True,
            document_sources=document_sources,
            time_cutoff=time_cutoff,
            kg_entities=kg_entities,
            kg_relationships=kg_relationships,
            kg_terms=kg_terms,
            kg_sources=kg_sources,
            kg_chunk_id_zero_only=kg_chunk_id_zero_only,
        )
        for packet in search_tool.run(query=question, override_kwargs=override_kwargs):
            if inference_sections_only and packet.id == "search_response_summary":
                top_sections = packet.response.top_sections[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ]
                results = cast(list[InferenceSection], top_sections)
                break
            # Otherwise wait for the final context documents packet.
            if packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
                results = cast(list[LlmDoc], packet.response)[
                    :KG_RESEARCH_NUM_RETRIEVED_DOCS
                ]
                break

    return results

View File

@@ -0,0 +1,163 @@
from enum import Enum
from operator import add
from typing import Annotated
from typing import Any
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
from onyx.context.search.models import InferenceSection
### States ###
class StepResults(BaseModel):
    """Artifacts produced by one KG answer step for a single (sub-)question."""
    question: str
    question_id: str
    answer: str
    # One retrieval result per sub-query issued while answering this question.
    sub_query_retrieval_results: list[QueryRetrievalResult]
    verified_reranked_documents: list[InferenceSection]
    context_documents: list[InferenceSection]
    cited_documents: list[InferenceSection]
class LoggerUpdate(BaseModel):
    """Base state update; `add` makes LangGraph concatenate values from parallel nodes."""
    log_messages: Annotated[list[str], add] = []
    # NOTE(review): no default makes step_results required on instantiation of
    # every subclass — confirm this is intended rather than `= []`.
    step_results: Annotated[list[SubQuestionAnswerResults], add]
class KGFilterConstructionResults(BaseModel):
    """Constructed Vespa/KG filters: global entity/relationship filters, per-object
    local entity filters, source-document filters, and structure info."""
    global_entity_filters: list[str]
    global_relationship_filters: list[str]
    # One filter list per research object (hence the nested list).
    local_entity_filters: list[list[str]]
    source_document_filters: list[str]
    structure: list[str]
class KGSearchType(Enum):
    """How a KG question is answered: vector search or generated SQL."""
    SEARCH = "SEARCH"
    SQL = "SQL"
class KGAnswerStrategy(Enum):
    """Answering strategy: DEEP (per-object research) or SIMPLE (direct answer)."""
    DEEP = "DEEP"
    SIMPLE = "SIMPLE"
class KGAnswerFormat(Enum):
    """Requested presentation of the final answer."""
    LIST = "LIST"
    TEXT = "TEXT"
class YesNoEnum(str, Enum):
    """String-valued yes/no flag (str subclass so it serializes as plain text)."""
    YES = "yes"
    NO = "no"
class AnalysisUpdate(LoggerUpdate):
    """State update from the question-analysis node: normalized entities,
    relationships, terms, time filter, and the chosen answering plan."""
    normalized_core_entities: list[str] = []
    normalized_core_relationships: list[str] = []
    query_graph_entities_no_attributes: list[str] = []
    query_graph_entities_w_attributes: list[str] = []
    query_graph_relationships: list[str] = []
    normalized_terms: list[str] = []
    normalized_time_filter: str | None = None
    # Plan decided during analysis; None until the corresponding node runs.
    strategy: KGAnswerStrategy | None = None
    output_format: KGAnswerFormat | None = None
    broken_down_question: str | None = None
    divide_and_conquer: YesNoEnum | None = None
    # Set when the question targets exactly one source document.
    single_doc_id: str | None = None
    search_type: KGSearchType | None = None
class SQLSimpleGenerationUpdate(LoggerUpdate):
    """State update from the SQL generation/execution nodes (simple path)."""
    sql_query: str | None = None
    sql_query_results: list[Dict[Any, Any]] | None = None
    # Per-object ("individualized") variant of the query and its results.
    individualized_sql_query: str | None = None
    individualized_query_results: list[Dict[Any, Any]] | None = None
    # Secondary query that resolves the source documents backing the answer.
    source_documents_sql: str | None = None
    source_document_results: list[str] | None = None
class ConsolidatedResearchUpdate(LoggerUpdate):
    """State update carrying the consolidated per-object research results as one string."""
    consolidated_research_object_results_str: str | None = None
class DeepSearchFilterUpdate(LoggerUpdate):
    """State update from the deep-search filter-construction node."""
    vespa_filter_results: KGFilterConstructionResults | None = None
    # Entities used to divide-and-conquer the research work.
    div_con_entities: list[str] | None = None
    source_division: bool | None = None
    global_entity_filters: list[str] | None = None
    global_relationship_filters: list[str] | None = None
    local_entity_filters: list[list[str]] | None = None
    source_filters: list[str] | None = None
class ResearchObjectOutput(LoggerUpdate):
    """Fan-in accumulator for per-object research results (`add` merges parallel branches)."""
    research_object_results: Annotated[list[dict[str, Any]], add] = []
class ERTExtractionUpdate(LoggerUpdate):
    """State update from entity/relationship/term (ERT) extraction."""
    # Stringified catalogs of known entity/relationship types for prompting.
    entities_types_str: str = ""
    relationship_types_str: str = ""
    extracted_entities_w_attributes: list[str] = []
    extracted_entities_no_attributes: list[str] = []
    extracted_relationships: list[str] = []
    extracted_terms: list[str] = []
    time_filter: str | None = None
class ResultsDataUpdate(LoggerUpdate):
    """State update with stringified query results for answer generation."""
    query_results_data_str: str | None = None
    individualized_query_results_data_str: str | None = None
    # Reference/source list appended to the answer outside the LLM prompt.
    reference_results_str: str | None = None
class ResearchObjectUpdate(LoggerUpdate):
    """Per-object research results (`add` merges parallel branches).
    NOTE(review): field duplicates ResearchObjectOutput.research_object_results —
    confirm both classes are needed."""
    research_object_results: Annotated[list[dict[str, Any]], add] = []
## Graph Input State
class MainInput(CoreState):
    """Graph input state; no fields beyond the shared CoreState."""
    pass
## Graph State
class MainState(
    # This includes the core state
    MainInput,
    ToolChoiceInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    # Below: every node's update type is mixed in so the graph state carries
    # all fields produced anywhere in the KG pipeline.
    ERTExtractionUpdate,
    AnalysisUpdate,
    SQLSimpleGenerationUpdate,
    ResultsDataUpdate,
    ResearchObjectOutput,
    DeepSearchFilterUpdate,
    ResearchObjectUpdate,
    ConsolidatedResearchUpdate,
):
    pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
    """Graph output state — presently only the accumulated log messages."""
    log_messages: list[str]
class ResearchObjectInput(LoggerUpdate):
    """Input for a single parallel research-object branch."""
    # Sequence number of this research branch.
    research_nr: int
    entity: str
    broken_down_question: str
    vespa_filter_results: KGFilterConstructionResults
    source_division: bool | None
    source_entity_filters: list[str] | None

View File

@@ -0,0 +1,29 @@
from onyx.agents.agent_search.kb_search.models import KGSteps
# User-facing descriptions and activity labels for each KG pipeline step,
# keyed by step number.
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
    1: KGSteps(
        description="Analyzing the question...",
        activities=[
            "Entities in Query",
            "Relationships in Query",
            "Terms in Query",
            "Time Filters",
        ],
    ),
    2: KGSteps(
        description="Planning the response approach...",
        activities=["Query Execution Strategy", "Answer Format"],
    ),
    3: KGSteps(
        # ".." -> "..." for consistency with the other step descriptions
        description="Querying the Knowledge Graph...",
        activities=[
            "Knowledge Graph Query",
            "Knowledge Graph Query Results",
            "Query for Source Documents",
            "Source Documents",
        ],
    ),
    4: KGSteps(
        description="Conducting further research on source documents...", activities=[]
    ),
}

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.kg.models import KGConfigSettings
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
@@ -68,6 +69,7 @@ class GraphSearchConfig(BaseModel):
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
kg_config_settings: KGConfigSettings = KGConfigSettings()
class GraphConfig(BaseModel):

View File

@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput,
)
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
@@ -85,7 +87,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput,
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -99,7 +101,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput,
input: BasicInput | MainInput | DCMainInput | KBMainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -149,6 +151,21 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_kb_graph(
    config: GraphConfig,
) -> AnswerStream:
    """Compile and run the knowledge-base (KG) agent graph.

    Emits an initial ToolCallKickoff for the user query, then yields the
    packets produced by the compiled graph.
    """
    graph = kb_graph_builder()
    compiled_graph = graph.compile()
    # Renamed from `input` to avoid shadowing the builtin.
    kb_input = KBMainInput(log_messages=[])
    yield ToolCallKickoff(
        tool_name="agent_search_0",
        tool_args={"query": config.inputs.search_request.query},
    )
    yield from run_graph(compiled_graph, config, kb_input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:

View File

@@ -159,3 +159,8 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
class QueryExpansionType(Enum):
    """Flavor of query expansion applied to a search query."""
    KEYWORD = "keyword"
    SEMANTIC = "semantic"
class ReferenceResults(BaseModel):
    """References backing an answer: citations plus non-citation entities.
    NOTE(review): presumably citations are document IDs/links — confirm with producer."""
    citations: list[str]
    general_entities: list[str]

View File

@@ -23,6 +23,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -119,6 +120,11 @@ def document_by_cc_pair_cleanup_task(
chunk_count=chunk_count,
)
delete_document_references_from_kg(
db_session=db_session,
document_id=document_id,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],

View File

@@ -11,6 +11,7 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -22,8 +23,10 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.db.kg_config import get_kg_config_settings
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@@ -127,8 +130,9 @@ class Answer:
self.search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
allow_refinement=False,
allow_agent_reranking=allow_agent_reranking,
kg_config_settings=get_kg_config_settings(db_session),
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
@@ -143,10 +147,19 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search:
if self.graph_config.behavior.use_agentic_search and (
self.graph_config.inputs.search_request.persona
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
and self.graph_config.inputs.search_request.persona.name.startswith(
"KG Dev"
)
):
run_langgraph = run_kb_graph
elif self.graph_config.behavior.use_agentic_search:
run_langgraph = run_main_graph
elif (
self.graph_config.inputs.search_request.persona
and USE_DIV_CON_AGENT
and self.graph_config.inputs.search_request.persona.description.startswith(
"DivCon Beta Agent"
)

View File

@@ -81,6 +81,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
@@ -100,6 +101,16 @@ from onyx.file_store.utils import get_user_files
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import save_files
from onyx.kg.clustering.incremental_cluster_updates import (
kg_incremental_cluster_updates,
)
from onyx.kg.clustering.initial_clustering import kg_clustering
from onyx.kg.configuration import populate_default_account_employee_definitions
from onyx.kg.configuration import populate_default_grounded_entity_types
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.kg.resets.reset_extractions import reset_extraction_kg_index
from onyx.kg.resets.reset_index import reset_full_kg_index
from onyx.kg.resets.reset_normalizations import reset_normalization_kg_index
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -664,6 +675,46 @@ def stream_chat_message_objects(
llm: LLM
kg_config_settings = get_kg_config_settings(db_session)
if kg_config_settings.KG_ENABLED:
# Temporarily, until we have a draft UI for the KG Operations/Management
# get Vespa index
search_settings = get_current_search_settings(db_session)
index_str = search_settings.index_name
if new_msg_req.message == "kg_e":
kg_extraction(tenant_id, index_str)
raise Exception("Extractions done")
elif new_msg_req.message == "kg_c":
kg_clustering(tenant_id, index_str)
raise Exception("Clustering done")
elif new_msg_req.message == "kg_i":
kg_incremental_cluster_updates(tenant_id, index_str)
raise Exception("Incremental clustering done")
elif new_msg_req.message == "kg_rs_full":
reset_full_kg_index()
raise Exception("Full KG index reset done")
elif new_msg_req.message == "kg_rs_extraction":
reset_extraction_kg_index()
raise Exception("Extraction KG index reset done")
elif new_msg_req.message == "kg_rs_normalization":
reset_normalization_kg_index()
raise Exception("Normalization KG index reset done")
elif new_msg_req.message == "kg_setup":
populate_default_grounded_entity_types()
populate_default_account_employee_definitions()
raise Exception("KG setup done")
try:
# Move these variables inside the try block
file_id_to_user_file = {}

View File

@@ -184,6 +184,13 @@ POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# Pool sizing for the read-only API-server database engine
# (used with the read-only DB user configured below).
POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE = int(
    os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE") or 10
)
POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
    os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW") or 5
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
@@ -746,3 +753,9 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)
# Knowledge Graph Read Only User Configuration
DB_READONLY_USER: str = os.environ.get("DB_READONLY_USER", "db_readonly_user")
# URL-quoted so it can be embedded directly in a connection URL.
# NOTE(review): falls back to the literal "password" when unset — confirm this
# default is acceptable outside local development.
DB_READONLY_PASSWORD: str = urllib.parse.quote_plus(
    os.environ.get("DB_READONLY_PASSWORD") or "password"
)

View File

@@ -102,3 +102,5 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
== "true"
)
# Feature flag (default off): enables routing eligible personas to the
# divide-and-conquer agent graph.
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"

View File

@@ -468,3 +468,8 @@ if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
class OnyxCallTypes(str, Enum):
    """Call providers; values are the lowercase provider names."""
    FIREFLIES = "fireflies"
    GONG = "gong"

View File

@@ -0,0 +1,10 @@
import os
# Cap on documents returned by a single KG research retrieval.
KG_RESEARCH_NUM_RETRIEVED_DOCS: int = int(
    os.environ.get("KG_RESEARCH_NUM_RETRIEVED_DOCS", "25")
)
# Cap on source links displayed with a simple KG answer.
KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES: int = int(
    os.environ.get("KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES", "10")
)

View File

@@ -193,12 +193,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
team {
name
}
assignee {
email
}
previousIdentifiers
subIssueSortOrder
priorityLabel
identifier
url
branchName
state {
id
name
}
customerTicketCount
description
comments {
@@ -267,7 +274,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
title=node["title"],
doc_updated_at=time_str_to_utc(node["updatedAt"]),
metadata={
"team": node["team"]["name"],
k: str(v)
for k, v in {
"team": (node.get("team") or {}).get("name"),
"assignee": (node.get("assignee") or {}).get("email"),
"state": (node.get("state") or {}).get("name"),
"priority": node.get("priority"),
"estimate": node.get("estimate"),
"started_at": node.get("startedAt"),
"completed_at": node.get("completedAt"),
"created_at": node.get("createdAt"),
"due_date": node.get("dueDate"),
}.items()
if v is not None
},
)
)

View File

@@ -35,6 +35,28 @@ logger = setup_logger()
# Parent object types fetched by default.
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
# Salesforce field name -> canonical metadata key, per parent object type.
# Fields present on a parent object are copied onto the document metadata.
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
    "Opportunity": {
        "Account": "account",
        "FiscalQuarter": "fiscal_quarter",
        "FiscalYear": "fiscal_year",
        "IsClosed": "is_closed",
        "Name": "name",
        "StageName": "stage_name",
        "Type": "type",
        "Amount": "amount",
        "CloseDate": "close_date",
        "Probability": "probability",
        "CreatedDate": "created_date",
        "LastModifiedDate": "last_modified_date",
    },
    "Contact": {
        "Account": "account",
        "CreatedDate": "created_date",
        "LastModifiedDate": "last_modified_date",
    },
}
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -168,6 +190,8 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# Always want to make sure user is grabbed for permissioning purposes
all_types.add("User")
# Always want to make sure account is grabbed for reference purposes
all_types.add("Account")
logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -279,7 +303,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
f"remaining={len(updated_ids) - docs_processed}"
)
for parent_id in parent_id_batch:
parent_object = sf_db.get_record(parent_id, parent_type)
parent_object = sf_db.get_record(
parent_id, parent_type, isChild=False
)
if not parent_object:
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
@@ -291,6 +317,18 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
doc.metadata["object_type"] = parent_type
# Add default attributes to the metadata
for (
sf_attribute,
canonical_attribute,
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
if sf_attribute in parent_object.data:
doc.metadata[canonical_attribute] = parent_object.data[
sf_attribute
]
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)

View File

@@ -172,7 +172,7 @@ def convert_sf_object_to_doc(
sections = [_extract_section(sf_object, base_url)]
for id in sf_db.get_child_ids(sf_object.id):
if not (child_object := sf_db.get_record(id)):
if not (child_object := sf_db.get_record(id, isChild=True)):
continue
sections.append(_extract_section(child_object, base_url))

View File

@@ -456,7 +456,7 @@ class OnyxSalesforceSQLite:
return result[0]
def get_record(
self, object_id: str, object_type: str | None = None
self, object_id: str, object_type: str | None = None, isChild: bool = False
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if self._conn is None:
@@ -469,15 +469,44 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
# Get the object data and account data
if object_type == "Account" or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
else:
cursor.execute(
"SELECT pso.data, r.parent_id as parent_id, sso.object_type FROM salesforce_objects pso \
LEFT JOIN relationships r on r.child_id = pso.id \
LEFT JOIN salesforce_objects sso on r.parent_id = sso.id \
WHERE pso.id = ? ",
(object_id,),
)
result = cursor.fetchall()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
data = json.loads(result[0][0])
if object_type != "Account":
# convert any account ids of the relationships back into data fields, with name
for row in result:
# the following skips Account objects.
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(self, object_type: str) -> list[str]:

View File

@@ -113,6 +113,11 @@ class BaseFilters(BaseModel):
tags: list[Tag] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
kg_sources: list[str] | None = None
kg_chunk_id_zero_only: bool | None = False
class IndexFilters(BaseFilters):

View File

@@ -182,6 +182,11 @@ def retrieval_preprocessing(
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
kg_entities=preset_filters.kg_entities,
kg_relationships=preset_filters.kg_relationships,
kg_terms=preset_filters.kg_terms,
kg_sources=preset_filters.kg_sources,
kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -16,6 +16,7 @@ from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.kg.models import KGConnectorData
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.models import StatusResponse
@@ -334,3 +335,29 @@ def mark_ccpair_with_indexing_trigger(
except Exception:
db_session.rollback()
raise
def get_kg_enabled_connectors(db_session: Session) -> list[KGConnectorData]:
    """Return id/source info for every connector with KG processing enabled.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[KGConnectorData]: One entry per connector whose
            ``kg_processing_enabled`` flag is set; ``source`` is lower-cased.

    Raises:
        Exception: any database error is logged and re-raised.
    """
    try:
        stmt = select(Connector.id, Connector.source).where(
            Connector.kg_processing_enabled
        )
        result = db_session.execute(stmt)
        return [
            KGConnectorData(id=row[0], source=row[1].lower())
            for row in result.fetchall()
        ]
    except Exception as e:
        logger.error(f"Error fetching KG-enabled connectors: {str(e)}")
        # Bare raise preserves the original traceback.
        raise

View File

@@ -21,23 +21,33 @@ from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.entities import delete_from_kg_entities__no_commit
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document
from onyx.db.models import Document as DbDocument
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import User
from onyx.db.relationships import delete_from_kg_relationships__no_commit
from onyx.db.relationships import (
delete_from_kg_relationships_extraction_staging__no_commit,
)
from onyx.db.tag import delete_document_tags_for_documents__no_commit
from onyx.db.utils import model_to_dict
from onyx.document_index.interfaces import DocumentMetadata
from onyx.kg.models import KGStage
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
@@ -392,6 +402,7 @@ def upsert_documents(
last_modified=datetime.now(timezone.utc),
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
kg_stage=KGStage.NOT_STARTED,
)
)
for doc in seen_documents.values()
@@ -605,7 +616,29 @@ def delete_documents_complete__no_commit(
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
# Start with the kg references
delete_from_kg_relationships__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_entities__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_relationships_extraction_staging__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_from_kg_entities_extraction_staging__no_commit(
db_session=db_session,
document_ids=document_ids,
)
# Continue with deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
@@ -858,3 +891,262 @@ def fetch_chunk_count_for_document(
) -> int | None:
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
return db_session.execute(stmt).scalar_one_or_none()
def get_unprocessed_kg_document_batch_for_connector(
    db_session: Session,
    connector_id: int,
    batch_size: int = 100,
) -> list[DbDocument]:
    """
    Retrieves a batch of documents that have not been processed for knowledge graph extraction.

    Args:
        db_session (Session): The database session to use
        connector_id (int): The ID of the connector to get documents for
        batch_size (int): The maximum number of documents to retrieve

    Returns:
        list[DbDocument]: List of documents that need KG processing
    """
    stmt = (
        select(DbDocument)
        .join(
            DocumentByConnectorCredentialPair,
            DbDocument.id == DocumentByConnectorCredentialPair.id,
        )
        .where(
            and_(
                DocumentByConnectorCredentialPair.connector_id == connector_id,
                # Unprocessed == no stage recorded yet or explicitly NOT_STARTED.
                or_(
                    DbDocument.kg_stage.is_(None),
                    DbDocument.kg_stage == KGStage.NOT_STARTED,
                ),
            )
        )
        .distinct()
        # Newest documents first.
        .order_by(DbDocument.doc_updated_at.desc())
        .limit(batch_size)
    )

    documents = db_session.scalars(stmt).all()
    # NOTE(review): flush() after a pure SELECT looks unnecessary — confirm
    # whether it is needed for pending-write visibility here.
    db_session.flush()

    return list(documents)
def get_kg_extracted_document_ids(db_session: Session) -> list[str]:
    """Return the IDs of every document whose KG stage is EXTRACTED.

    Args:
        db_session: The database session to use.
    """
    extracted_filter = DbDocument.kg_stage == KGStage.EXTRACTED
    query = select(DbDocument.id).where(extracted_filter)
    return [doc_id for doc_id in db_session.scalars(query)]
def update_document_kg_info(
    db_session: Session, document_id: str, kg_stage: KGStage
) -> None:
    """Updates the knowledge graph stage for a document.

    Note: does not commit or flush; an update for a non-existent ID silently
    affects zero rows (no error is raised).

    Args:
        db_session (Session): The database session to use
        document_id (str): The ID of the document to update
        kg_stage (KGStage): The stage of the knowledge graph processing for the document
    """
    stmt = (
        update(DbDocument)
        .where(DbDocument.id == document_id)
        .values(
            kg_stage=kg_stage,
        )
    )

    db_session.execute(stmt)
def update_document_kg_stage(
    db_session: Session,
    document_id: str,
    kg_stage: KGStage,
) -> None:
    """Set the KG stage of a single document and flush the session."""
    db_session.execute(
        update(DbDocument)
        .where(DbDocument.id == document_id)
        .values(kg_stage=kg_stage)
    )
    db_session.flush()
def get_document_kg_info(
    db_session: Session,
    document_id: str,
) -> KGStage | None:
    """Retrieves the knowledge graph processing stage for a document.

    Args:
        db_session (Session): The database session to use
        document_id (str): The ID of the document to query

    Returns:
        KGStage | None: The document's KG stage (NOT_STARTED when the column
            is unset). Returns None if the document is not found.
    """
    stmt = select(DbDocument.kg_stage).where(DbDocument.id == document_id)
    result = db_session.execute(stmt).one_or_none()

    if result is None:
        return None

    # Unset stage is treated as NOT_STARTED.
    return result.kg_stage or KGStage.NOT_STARTED
def get_all_kg_extracted_documents_info(
    db_session: Session,
) -> list[str]:
    """Retrieve the IDs of all documents whose kg_stage is EXTRACTED.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[str]: Document IDs (ordered by ID) in the EXTRACTED KG stage
    """
    stmt = (
        select(DbDocument.id)
        .where(DbDocument.kg_stage == KGStage.EXTRACTED)
        .order_by(DbDocument.id)
    )
    # BUG FIX: iterate scalars, not Row objects. The original str()-ed each
    # Row, which yields strings like "('<id>',)" instead of the id itself.
    return [str(doc_id) for doc_id in db_session.scalars(stmt).all()]
def get_base_llm_doc_information(
    db_session: Session, document_ids: list[str]
) -> list[str]:
    """Render markdown-style source lines for the given documents.

    Args:
        db_session (Session): The database session to use
        document_ids (list[str]): IDs of the documents to render

    Returns:
        list[str]: Lines of the form "* [semantic_id](link) (updated_at)",
        capped at KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES entries.
    """
    stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
    # scalars() yields the ORM objects directly — no Row unpacking needed;
    # the original also carried an unused enumerate() index.
    documents = [
        f"""* [{doc.semantic_id}]({doc.link}) ({doc.doc_updated_at})"""
        for doc in db_session.scalars(stmt)
    ]
    return documents[:KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES]
def get_document_updated_at(
    document_id: str,
    db_session: Session,
) -> datetime | None:
    """Retrieve the doc_updated_at timestamp for a given document ID.

    Document IDs may carry a single "prefix:" segment, which is stripped
    before the lookup.

    Args:
        document_id (str): The ID of the document to query
        db_session (Session): The database session to use

    Returns:
        Optional[datetime]: The doc_updated_at timestamp if found, None if
        the document doesn't exist

    Raises:
        ValueError: If the document ID contains more than one ':'.
    """
    # Split once instead of recomputing split() for every branch.
    parts = document_id.split(":")
    if len(parts) == 2:
        document_id = parts[1]
    elif len(parts) > 2:
        raise ValueError(f"Invalid document ID: {document_id}")

    stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
    return db_session.execute(stmt).scalar_one_or_none()
def reset_all_document_kg_stages(db_session: Session) -> int:
    """Reset every document not already in NOT_STARTED back to NOT_STARTED.

    Args:
        db_session (Session): The database session to use

    Returns:
        int: Number of documents that were reset
    """
    reset_stmt = (
        update(DbDocument)
        .where(DbDocument.kg_stage != KGStage.NOT_STARTED)
        .values(kg_stage=KGStage.NOT_STARTED)
    )
    exec_result = db_session.execute(reset_stmt)
    # rowcount always exists at runtime for UPDATE statements; the hasattr
    # guard only exists to satisfy the type checker.
    if hasattr(exec_result, "rowcount"):
        return exec_result.rowcount
    return 0
def update_document_kg_stages(
    db_session: Session, source_stage: KGStage, target_stage: KGStage
) -> int:
    """Move every document currently in source_stage into target_stage.

    Used e.g. as part of the reset flow for documents that have been
    extracted but not yet clustered.

    Args:
        db_session (Session): The database session to use
        source_stage (KGStage): Only documents in this stage are updated
        target_stage (KGStage): The stage matching documents are moved into

    Returns:
        int: Number of documents that were updated
    """
    stmt = (
        update(DbDocument)
        .where(DbDocument.kg_stage == source_stage)
        .values(kg_stage=target_stage)
    )
    result = db_session.execute(stmt)
    # The hasattr check is needed for type checking, even though rowcount
    # is guaranteed to exist at runtime for UPDATE operations
    return result.rowcount if hasattr(result, "rowcount") else 0
def get_skipped_kg_documents(db_session: Session) -> list[str]:
    """Return the IDs of all documents whose kg_stage is SKIPPED.

    Args:
        db_session (Session): The database session to use

    Returns:
        list[str]: Document IDs that were skipped during KG processing
    """
    skipped_query = select(DbDocument.id).where(
        DbDocument.kg_stage == KGStage.SKIPPED
    )
    return [doc_id for doc_id in db_session.scalars(skipped_query)]
def get_kg_doc_info_for_entity_name(
    db_session: Session, document_id: str, entity_type: str
) -> KGEntityDocInfo:
    """
    Build the semantic name/link info for an entity backed by a document.
    """
    doc_row = (
        db_session.query(Document.semantic_id, Document.link)
        .filter(Document.id == document_id)
        .first()
    )

    if doc_row is None:
        # NOTE(review): this fallback keeps entity_type in its original case
        # while the found-branch upper-cases it — confirm the asymmetry is
        # intended.
        fallback_name = f"{entity_type}:{document_id}"
        return KGEntityDocInfo(
            doc_id=None,
            doc_semantic_id=None,
            doc_link=None,
            semantic_entity_name=fallback_name,
            semantic_linked_entity_name=fallback_name,
        )

    semantic_id, link = doc_row[0], doc_row[1]
    display_name = f"{entity_type.upper()}:{semantic_id}"
    return KGEntityDocInfo(
        doc_id=document_id,
        doc_semantic_id=semantic_id,
        doc_link=link,
        semantic_entity_name=display_name,
        semantic_linked_entity_name=f"[{display_name}]({link})",
    )

View File

@@ -27,6 +27,8 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -187,7 +189,9 @@ def is_valid_schema_name(name: str) -> bool:
class SqlEngine:
_engine: Engine | None = None
_readonly_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_readonly_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
@classmethod
@@ -252,12 +256,80 @@ class SqlEngine:
cls._engine = engine
@classmethod
def init_readonly_engine(
    cls,
    pool_size: int,
    # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
    max_overflow: int,
    **extra_engine_kwargs: Any,
) -> None:
    """Initialize the shared read-only engine at most once (thread-safe).

    NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
    important args, and if incorrectly specified, we have run into hitting the pool
    limit / using too many connections and overwhelming the database.

    Raises:
        ValueError: if DB_READONLY_USER or DB_READONLY_PASSWORD is not set.
    """
    with cls._readonly_lock:
        # Idempotent: a second caller returns without touching the engine.
        if cls._readonly_engine:
            return

        if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
            raise ValueError(
                "Custom database user credentials not configured in environment variables"
            )

        # Build connection string with custom user
        connection_string = build_connection_string(
            user=DB_READONLY_USER,
            password=DB_READONLY_PASSWORD,
            use_iam_auth=False,  # Custom users typically don't use IAM auth
            db_api=SYNC_DB_API,  # Explicitly use sync DB API
        )

        # Start with base kwargs that are valid for all pool types
        final_engine_kwargs: dict[str, Any] = {}

        if POSTGRES_USE_NULL_POOL:
            # if null pool is specified, then we need to make sure that
            # we remove any passed in kwargs related to pool size that would
            # cause the initialization to fail
            final_engine_kwargs.update(extra_engine_kwargs)

            final_engine_kwargs["poolclass"] = pool.NullPool
            if "pool_size" in final_engine_kwargs:
                del final_engine_kwargs["pool_size"]
            if "max_overflow" in final_engine_kwargs:
                del final_engine_kwargs["max_overflow"]
        else:
            final_engine_kwargs["pool_size"] = pool_size
            final_engine_kwargs["max_overflow"] = max_overflow
            final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
            final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

            # any passed in kwargs override the defaults
            final_engine_kwargs.update(extra_engine_kwargs)

        logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
        # echo=True here for inspecting all emitted db queries
        engine = create_engine(connection_string, **final_engine_kwargs)

        if USE_IAM_AUTH:
            # NOTE(review): the IAM listener is attached even though this engine
            # was built with use_iam_auth=False above — confirm this combination
            # is intended.
            event.listen(engine, "do_connect", provide_iam_token)

        cls._readonly_engine = engine
@classmethod
def get_engine(cls) -> Engine:
    """Return the primary engine; raise if init_engine was never called."""
    engine = cls._engine
    if not engine:
        raise RuntimeError("Engine not initialized. Must call init_engine first.")
    return engine
@classmethod
def get_readonly_engine(cls) -> Engine:
    """Return the read-only engine; raise if init_readonly_engine was never called."""
    readonly_engine = cls._readonly_engine
    if not readonly_engine:
        raise RuntimeError(
            "Readonly engine not initialized. Must call init_readonly_engine first."
        )
    return readonly_engine
@classmethod
def set_app_name(cls, app_name: str) -> None:
    """Record the application name stored on the class (cls._app_name)."""
    cls._app_name = app_name
@@ -307,6 +379,10 @@ def get_sqlalchemy_engine() -> Engine:
return SqlEngine.get_engine()
def get_readonly_sqlalchemy_engine() -> Engine:
    """Module-level convenience wrapper around SqlEngine.get_readonly_engine()."""
    return SqlEngine.get_readonly_engine()
async def get_async_connection() -> Any:
"""
Custom connection function for async engine when using IAM auth.
@@ -560,3 +636,46 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@contextmanager
def get_db_readonly_user_session_with_current_tenant() -> (
    Generator[Session, None, None]
):
    """
    Generate a database session using a custom read-only database user for the
    current tenant. The custom user credentials are obtained from environment
    variables.

    Yields:
        Session: bound to a connection whose search_path is set to the
        current tenant's schema.

    Raises:
        HTTPException: if the current tenant id is not a valid schema name.
    """
    tenant_id = get_current_tenant_id()

    readonly_engine = get_readonly_sqlalchemy_engine()

    # BUG FIX: register the checkout listener only once. Calling event.listen
    # unconditionally stacked a duplicate listener on the engine for every
    # session created through this context manager.
    if not event.contains(
        readonly_engine, "checkout", _set_search_path_on_checkout__listener
    ):
        event.listen(
            readonly_engine, "checkout", _set_search_path_on_checkout__listener
        )

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with readonly_engine.connect() as connection:
        dbapi_connection = connection.connection
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(f'SET search_path = "{tenant_id}"')
            if POSTGRES_IDLE_SESSIONS_TIMEOUT:
                # BUG FIX: a raw DBAPI cursor requires a plain SQL string;
                # wrapping it in sqlalchemy.text() fails at execute time.
                cursor.execute(
                    f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
                )
        finally:
            cursor.close()

        with Session(bind=connection, expire_on_commit=False) as session:
            try:
                yield session
            finally:
                if MULTI_TENANT:
                    # Restore the default search_path on the pooled connection
                    cursor = dbapi_connection.cursor()
                    try:
                        cursor.execute('SET search_path TO "$user", public')
                    finally:
                        cursor.close()

329
backend/onyx/db/entities.py Normal file
View File

@@ -0,0 +1,329 @@
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
from typing import Type
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session
from onyx.db.models import Document
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGEntityType
from onyx.kg.models import KGGroundingType
from onyx.kg.models import KGStage
def add_entity(
    db_session: Session,
    kg_stage: KGStage,
    entity_type: str,
    name: str,
    document_id: str | None = None,
    occurrences: int = 0,
    event_time: datetime | None = None,
    attributes: dict[str, str] | None = None,
) -> "KGEntity | KGEntityExtractionStaging | None":
    """Upsert an entity into the staging (EXTRACTED) or final (NORMALIZED) table.

    On conflict with an existing id_name, the occurrence count is incremented
    and the remaining fields are overwritten.

    Args:
        db_session: SQLAlchemy session
        kg_stage: KGStage of the entity; selects the target table
        entity_type: Type of the entity (upper-cased; must match an existing
            KGEntityType id_name)
        name: Name of the entity (title-cased before storage)
        document_id: Optional backing document; when set, the document's
            kg_stage and kg_processing_time are updated as a side effect
        occurrences: Number of clusters this entity has been found in
        event_time: Optional event timestamp for the entity
        attributes: Optional free-form attribute map

    Returns:
        The created/updated row object, or None if the insert returned nothing

    Raises:
        ValueError: for any stage other than EXTRACTED/NORMALIZED
    """
    # Normalize the identifying fields; id_name is "<TYPE>:<Name>"
    entity_type = entity_type.upper()
    name = name.title()
    id_name = f"{entity_type}:{name}"

    # Select the target table based on the pipeline stage
    _KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGEntityObject = KGEntityExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGEntityObject = KGEntity
    else:
        raise ValueError(f"Invalid KGStage: {kg_stage}")

    # Create new entity
    stmt = (
        pg_insert(_KGEntityObject)
        .values(
            id_name=id_name,
            entity_type_id_name=entity_type,
            document_id=document_id,
            name=name,
            occurrences=occurrences,
            event_time=event_time,
            attributes=attributes,
        )
        .on_conflict_do_update(
            index_elements=["id_name"],
            set_=dict(
                # Direct numeric addition without text()
                occurrences=_KGEntityObject.occurrences + occurrences,
                # Keep other fields updated as before
                entity_type_id_name=entity_type,
                document_id=document_id,
                name=name,
                event_time=event_time,
                attributes=attributes,
            ),
        )
        .returning(_KGEntityObject)
    )
    result = db_session.execute(stmt).scalar()

    # Update the document's kg_stage if document_id is provided
    if document_id is not None:
        db_session.query(Document).filter(Document.id == document_id).update(
            {"kg_stage": kg_stage, "kg_processing_time": datetime.now(timezone.utc)}
        )

    db_session.flush()
    return result
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
    """Return the KGEntity backed by the given document id, if one exists.

    Args:
        db: SQLAlchemy database session
        document_id: The document ID to search for

    Returns:
        The matching KGEntity, or None when no entity references the document
    """
    lookup = select(KGEntity).where(KGEntity.document_id == document_id)
    return db.execute(lookup).scalar()
def get_entities_by_grounding(
    db_session: Session, kg_stage: KGStage, grounding: KGGroundingType
) -> List[KGEntity] | List[KGEntityExtractionStaging]:
    """Fetch all entities whose entity type carries the given grounding.

    Args:
        db_session: SQLAlchemy session
        kg_stage: which table to read (EXTRACTED staging vs NORMALIZED)
        grounding: grounding type to filter entity types by

    Returns:
        Entities from the selected stage whose type matches the grounding

    Raises:
        ValueError: for any stage other than EXTRACTED/NORMALIZED
    """
    _KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGEntityObject = KGEntityExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGEntityObject = KGEntity
    else:
        raise ValueError(f"Invalid KGStage: {kg_stage}")

    entities = list(
        db_session.query(_KGEntityObject)
        .join(
            KGEntityType,
            _KGEntityObject.entity_type_id_name == KGEntityType.id_name,
        )
        .filter(KGEntityType.grounding == grounding)
        .all()
    )

    if kg_stage == KGStage.EXTRACTED:
        return cast(List[KGEntityExtractionStaging], entities)
    return cast(List[KGEntity], entities)
def get_grounded_entities_by_types(
    db_session: Session, entity_types: List[str], grounding: KGGroundingType
) -> List[KGEntity]:
    """Fetch normalized entities of the given types with the given grounding.

    Args:
        db_session: SQLAlchemy session
        entity_types: List of entity type id_names to filter by
        grounding: grounding type the entity types must carry

    Returns:
        KGEntity rows belonging to the specified entity types
    """
    query = (
        db_session.query(KGEntity)
        .join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
        .filter(
            KGEntity.entity_type_id_name.in_(entity_types),
            KGEntityType.grounding == grounding,
        )
    )
    return query.all()
def delete_entities_by_id_names(
    db_session: Session, id_names: list[str], kg_stage: KGStage
) -> int:
    """
    Delete entities from the database based on a list of id_names.

    Args:
        db_session: SQLAlchemy database session
        id_names: List of entity id_names to delete
        kg_stage: which table to delete from (EXTRACTED staging vs NORMALIZED)

    Returns:
        Number of entities deleted

    Raises:
        ValueError: for any stage other than EXTRACTED/NORMALIZED
    """
    # Single validation point — the original duplicated the stage check,
    # leaving an unreachable trailing else-raise.
    _KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGEntityObject = KGEntityExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGEntityObject = KGEntity
    else:
        raise ValueError(f"Invalid KGStage: {kg_stage}")

    deleted_count = (
        db_session.query(_KGEntityObject)
        .filter(_KGEntityObject.id_name.in_(id_names))
        .delete(synchronize_session=False)
    )
    db_session.flush()  # Flush to ensure deletion is processed
    return deleted_count
def get_entity_names_for_types(
    db_session: Session, entity_types: List[str]
) -> List[tuple[str, str | None]]:
    """Pair each entity of the given types with a display name.

    The display name is "<ENTITY_TYPE>:<document semantic_id>" when the
    entity is backed by a document, otherwise the entity's own id_name.

    Args:
        db_session: SQLAlchemy session
        entity_types: List of entity type id_names to filter by

    Returns:
        List of (entity id_name, display name) tuples
    """
    # PERF FIX: materialize the entities once; the original executed the
    # same query twice (two DB round-trips) via entity_query.all().
    entities = (
        db_session.query(KGEntity)
        .filter(KGEntity.entity_type_id_name.in_(entity_types))
        .all()
    )

    # Get document IDs from the filtered entities
    doc_ids = [e.document_id for e in entities if e.document_id is not None]

    # Map capitalized document id -> (semantic_id, link)
    doc_info: dict[str, tuple[str | None, str | None]] = {
        row[0].capitalize(): (row[1], row[2])
        for row in db_session.query(Document.id, Document.semantic_id, Document.link)
        .filter(Document.id.in_(doc_ids))
        .all()
    }

    # Return entities with their document info
    names: list[tuple[str, str | None]] = []
    for entity in entities:
        if entity.document_id is None:
            names.append((entity.id_name, entity.id_name))
            continue

        # Extract entity type from the full type ID
        entity_type = entity.entity_type_id_name.split(":")[0].upper()

        # Get document info, defaulting to None if not found
        doc_semantic_id = doc_info.get(entity.document_id.capitalize(), (None, None))[0]

        # Construct the final string
        names.append((entity.id_name, f"{entity_type}:{doc_semantic_id}"))

    return names
def get_entities_by_document_ids(
    db_session: Session, document_ids: list[str], kg_stage: KGStage
) -> List[str]:
    """Return entity id_names belonging to the specified document ids.

    Args:
        db_session: SQLAlchemy database session
        document_ids: List of document ids to filter by
        kg_stage: which table to read (EXTRACTED staging vs NORMALIZED)

    Returns:
        List of entity id_names belonging to the specified document ids

    Raises:
        ValueError: for any stage other than EXTRACTED/NORMALIZED
    """
    # NOTE(review): only the stored document_id is lower-cased before the
    # membership test, so the caller's document_ids are presumably already
    # lowercase — verify against callers.
    if kg_stage == KGStage.EXTRACTED:
        stmt = select(KGEntityExtractionStaging.id_name).where(
            func.lower(KGEntityExtractionStaging.document_id).in_(document_ids)
        )
    elif kg_stage == KGStage.NORMALIZED:
        stmt = select(KGEntity.id_name).where(
            func.lower(KGEntity.document_id).in_(document_ids)
        )
    else:
        raise ValueError(f"Invalid KGStage: {kg_stage.value}")

    return [id_name for id_name in db_session.execute(stmt).scalars()]
def get_document_id_for_entity(
    db_session: Session, entity: str, kg_stage: KGStage = KGStage.NORMALIZED
) -> str | None:
    """Get the document ID associated with an entity.

    Args:
        db_session: SQLAlchemy database session
        entity: The entity id_name to look up ("TYPE: name" is normalized to
            "TYPE:name" before matching, case-insensitively)
        kg_stage: The knowledge graph stage to search in (defaults to NORMALIZED)

    Returns:
        The document ID if found, None otherwise

    Raises:
        ValueError: for any stage other than EXTRACTED/NORMALIZED
    """
    # Collapse the "TYPE: name" spelling into the canonical "TYPE:name"
    entity = entity.replace(": ", ":")

    _KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGEntityObject = KGEntityExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGEntityObject = KGEntity
    else:
        raise ValueError(f"Invalid KGStage: {kg_stage}")

    lookup = select(_KGEntityObject.document_id).where(
        func.lower(_KGEntityObject.id_name) == func.lower(entity)
    )
    return db_session.execute(lookup).scalars().first()
def delete_from_kg_entities_extraction_staging__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete staged entities for the given documents (caller must commit)."""
    staged_rows = db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.document_id.in_(document_ids)
    )
    staged_rows.delete(synchronize_session=False)
def delete_from_kg_entities__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete normalized entities for the given documents (caller must commit)."""
    normalized_rows = db_session.query(KGEntity).filter(
        KGEntity.document_id.in_(document_ids)
    )
    normalized_rows.delete(synchronize_session=False)

View File

@@ -0,0 +1,257 @@
from typing import List
from sqlalchemy.orm import Session
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import KGEntityType
from onyx.kg.kg_default_entity_definitions import KGDefaultAccountEmployeeDefinitions
from onyx.kg.kg_default_entity_definitions import (
KGDefaultPrimaryGroundedEntityDefinitions,
)
from onyx.kg.models import KGGroundingType
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
    """Get all entity types that have non-null entity_values.

    Args:
        db_session: SQLAlchemy session

    Returns:
        List of KGEntityType objects that have entity_values defined
    """
    has_values = KGEntityType.entity_values.isnot(None)
    return db_session.query(KGEntityType).filter(has_values).all()
def get_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
    """Get all entity types whose grounding is GROUNDED.

    Args:
        db_session: SQLAlchemy session

    Returns:
        List of KGEntityType objects with grounding == GROUNDED
    """
    is_grounded = KGEntityType.grounding == KGGroundingType.GROUNDED
    return db_session.query(KGEntityType).filter(is_grounded).all()
def get_entity_types_with_grounded_source_name(
    db_session: Session,
) -> List[KGEntityType]:
    """Get all entity types that have a non-null grounded_source_name.

    Args:
        db_session: SQLAlchemy session

    Returns:
        List of KGEntityType objects with grounded_source_name defined
    """
    has_source = KGEntityType.grounded_source_name.isnot(None)
    return db_session.query(KGEntityType).filter(has_source).all()
def get_entity_type_by_grounded_source_name(
    db_session: Session, grounded_source_name: KGGroundingType
) -> KGEntityType | None:
    """Return the entity type matching the given grounded_source_name.

    Args:
        db_session: SQLAlchemy session
        grounded_source_name: value matched against
            KGEntityType.grounded_source_name.
            NOTE(review): annotated as KGGroundingType, but it is compared
            against the grounded_source_name column — the annotation looks
            wrong; confirm it should be str.

    Returns:
        The matching KGEntityType, or None if not found
    """
    # first() already returns None when there is no match; the original's
    # explicit None check (and "as a dictionary" docstring) were wrong.
    return (
        db_session.query(KGEntityType)
        .filter(KGEntityType.grounded_source_name == grounded_source_name)
        .first()
    )
def get_entity_types(
    db_session: Session,
    active: bool | None = True,
) -> list[KGEntityType]:
    """Return all entity types ordered by id_name.

    When active is None all entity types are returned; otherwise only those
    whose active flag matches.
    """
    query = db_session.query(KGEntityType)
    if active is not None:
        query = query.filter(KGEntityType.active == active)
    return query.order_by(KGEntityType.id_name).all()
def populate_default_primary_grounded_entity_type_information(
    db_session: Session,
) -> None:
    """Seed the default primary grounded entity types (created inactive).

    Args:
        db_session: SQLAlchemy session

    Raises:
        ValueError: if the KG is disabled or vendor settings are missing
    """
    # get kg config information
    kg_config_settings = get_kg_config_settings(db_session)
    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    if not kg_config_settings.KG_VENDOR:
        raise ValueError("KG_VENDOR is not set")
    if not kg_config_settings.KG_VENDOR_DOMAINS:
        raise ValueError("KG_VENDOR_DOMAINS is not set")

    # Entity types already present are left untouched
    existing_id_names = {et.id_name for et in db_session.query(KGEntityType).all()}

    default_definitions = KGDefaultPrimaryGroundedEntityDefinitions()
    for id_name, definition in default_definitions.model_dump().items():
        if id_name in existing_id_names:
            continue

        # Substitute the configured vendor name into the description template
        description = definition["description"].replace(
            "---vendor_name---", kg_config_settings.KG_VENDOR
        )

        db_session.add(
            KGEntityType(
                id_name=id_name,
                description=description,
                grounding=definition["grounding"],
                grounded_source_name=definition["grounded_source_name"],
                active=False,
            )
        )

    db_session.flush()
def populate_default_employee_account_information(db_session: Session) -> None:
    """Seed the default account/employee entity types into the KG.

    Unlike the primary-grounded seeding, the active flag comes from each
    default definition rather than being forced to False.

    Args:
        db_session: SQLAlchemy session

    Raises:
        ValueError: if the KG is disabled or vendor settings are missing
    """
    # get kg config information
    kg_config_settings = get_kg_config_settings(db_session)
    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    if not kg_config_settings.KG_VENDOR:
        raise ValueError("KG_VENDOR is not set")
    if not kg_config_settings.KG_VENDOR_DOMAINS:
        raise ValueError("KG_VENDOR_DOMAINS is not set")

    # Entity types already present are left untouched
    existing_id_names = {et.id_name for et in db_session.query(KGEntityType).all()}

    default_definitions = KGDefaultAccountEmployeeDefinitions()
    for id_name, definition in default_definitions.model_dump().items():
        if id_name in existing_id_names:
            continue

        # Substitute the configured vendor name into the description template
        description = definition["description"].replace(
            "---vendor_name---", kg_config_settings.KG_VENDOR
        )

        db_session.add(
            KGEntityType(
                id_name=id_name,
                description=description,
                grounding=definition["grounding"],
                grounded_source_name=definition["grounded_source_name"],
                active=definition["active"],
            )
        )

    db_session.flush()
def get_grounded_entity_types_with_null_grounded_source(
    db_session: Session,
) -> List[KGEntityType]:
    """Get GROUNDED entity types that have no grounded_source_name.

    Args:
        db_session: SQLAlchemy session

    Returns:
        List of KGEntityType objects with null grounded_source_name and
        grounding == GROUNDED
    """
    return (
        db_session.query(KGEntityType)
        .filter(
            KGEntityType.grounded_source_name.is_(None),
            KGEntityType.grounding == KGGroundingType.GROUNDED,
        )
        .all()
    )
def get_entity_types_by_grounding(
    db_session: Session,
    grounding: KGGroundingType,
) -> List[KGEntityType]:
    """Get all entity types with the specified grounding.

    Args:
        db_session: SQLAlchemy session
        grounding: The grounding type to filter by

    Returns:
        List of KGEntityType objects with the specified grounding
    """
    grounding_filter = KGEntityType.grounding == grounding
    return db_session.query(KGEntityType).filter(grounding_filter).all()
def get_grounded_source_name(db_session: Session, entity_type: str) -> str | None:
    """
    Return the grounded source name for an entity type, or None if the
    entity type does not exist.
    """
    entity_type_row = (
        db_session.query(KGEntityType)
        .filter(KGEntityType.id_name == entity_type)
        .first()
    )
    return None if entity_type_row is None else entity_type_row.grounded_source_name

View File

@@ -0,0 +1,34 @@
from sqlalchemy.orm import Session
from onyx.db.models import KGConfig
from onyx.kg.models import KGConfigSettings
from onyx.kg.models import KGConfigVars
def get_kg_enablement(db_session: Session) -> bool:
    """Return True iff a KG_ENABLED config row exists with values ["true"].

    Args:
        db_session: SQLAlchemy session
    """
    # BUG FIX: the original combined the two column comparisons with the
    # Python `and` operator, which does not build a SQL conjunction
    # (truth-testing a SQLAlchemy clause is undefined). Passing both
    # criteria to filter() AND-s them in SQL.
    check = (
        db_session.query(KGConfig.kg_variable_values)
        .filter(
            KGConfig.kg_variable_name == "KG_ENABLED",
            KGConfig.kg_variable_values == ["true"],
        )
        .first()
    )
    return check is not None
def get_kg_config_settings(db_session: Session) -> KGConfigSettings:
    """Assemble a KGConfigSettings object from the kg_config table rows.

    Unrecognized variable names are ignored; variables not present in the
    table keep the KGConfigSettings defaults.
    """
    results = db_session.query(KGConfig).all()
    kg_config_settings = KGConfigSettings()
    for result in results:
        # NOTE(review): KG_ENABLED is matched by string literal while the
        # other variables use KGConfigVars — confirm whether this should be
        # unified.
        if result.kg_variable_name == "KG_ENABLED":
            # Assumes kg_variable_values is non-empty; an empty array would
            # raise IndexError here — TODO confirm upstream guarantees.
            kg_config_settings.KG_ENABLED = result.kg_variable_values[0] == "true"
        elif result.kg_variable_name == KGConfigVars.KG_VENDOR:
            kg_config_settings.KG_VENDOR = result.kg_variable_values[0]
        elif result.kg_variable_name == KGConfigVars.KG_VENDOR_DOMAINS:
            kg_config_settings.KG_VENDOR_DOMAINS = result.kg_variable_values
        elif result.kg_variable_name == KGConfigVars.KG_IGNORE_EMAIL_DOMAINS:
            kg_config_settings.KG_IGNORE_EMAIL_DOMAINS = result.kg_variable_values
    return kg_config_settings

View File

@@ -0,0 +1,131 @@
from sqlalchemy import text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_READONLY_USER
Base = declarative_base()
# First, create the view definition
def create_views(
    db_session: Session,
    user_email: str,
    allowed_docs_view_name: str = "allowed_docs",
    kg_relationships_view_name: str = "kg_relationships_with_access",
) -> None:
    """Create per-user access views over the KG tables and grant read access.

    Two views are created (CREATE OR REPLACE) and committed:
      1. allowed_docs: document ids the given user may see (public docs,
         docs owned via the user's credentials, and docs granted via
         external user emails / external group membership).
      2. kg_relationships_with_access: kg_relationship rows joined down to
         only those whose source document is in allowed_docs, enriched with
         document date and source/target entity attributes.

    SELECT on the relationships view is granted to DB_READONLY_USER.

    NOTE(review): the view names are interpolated into the SQL via f-strings;
    only user_email is bound as a parameter. Callers must not pass untrusted
    view names.

    Args:
        db_session: SQLAlchemy session
        user_email: email whose document access the views encode
        allowed_docs_view_name: Name of the allowed_docs view
        kg_relationships_view_name: Name of the kg_relationships view
    """
    # Create ALLOWED_DOCS view
    allowed_docs_view = text(
        f"""
    CREATE OR REPLACE VIEW {allowed_docs_view_name} AS
    WITH public_docs AS (
        SELECT d.id as allowed_doc_id
        FROM document_by_connector_credential_pair d
        JOIN connector_credential_pair ccp ON
            d.connector_id = ccp.connector_id AND
            d.credential_id = ccp.credential_id
        WHERE ccp.status != 'DELETING'
        AND ccp.access_type = 'PUBLIC'
    ),
    user_owned_docs AS (
        SELECT d.id as allowed_doc_id
        FROM document_by_connector_credential_pair d
        JOIN credential c ON d.credential_id = c.id
        JOIN connector_credential_pair ccp ON
            d.connector_id = ccp.connector_id AND
            d.credential_id = ccp.credential_id
        JOIN "user" u ON c.user_id = u.id
        WHERE ccp.status != 'DELETING'
        AND ccp.access_type != 'SYNC'
        AND u.email = :user_email
    ),
    external_user_docs AS (
        SELECT id as allowed_doc_id
        FROM document
        WHERE :user_email = ANY(external_user_emails)
    ),
    external_group_docs AS (
        SELECT d.id as allowed_doc_id
        FROM document d
        JOIN user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
        JOIN "user" u ON ueg.user_id = u.id
        WHERE u.email = :user_email
    )
    SELECT DISTINCT allowed_doc_id FROM (
        SELECT allowed_doc_id FROM public_docs
        UNION
        SELECT allowed_doc_id FROM user_owned_docs
        UNION
        SELECT allowed_doc_id FROM external_user_docs
        UNION
        SELECT allowed_doc_id FROM external_group_docs
    ) combined_docs
    """
    ).bindparams(user_email=user_email)

    # Create the main view that uses ALLOWED_DOCS
    kg_relationships_view = text(
        f"""
    CREATE OR REPLACE VIEW {kg_relationships_view_name} AS
    SELECT kgr.id_name as relationship,
        kgr.source_node as source_entity,
        kgr.target_node as target_entity,
        kgr.source_node_type as source_entity_type,
        kgr.target_node_type as target_entity_type,
        kgr.type as relationship_description,
        kgr.relationship_type_id_name as relationship_type,
        kgr.source_document as source_document,
        d.doc_updated_at as source_date,
        se.attributes as source_entity_attributes,
        te.attributes as target_entity_attributes
    FROM kg_relationship kgr
    INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document
    JOIN document d on d.id = kgr.source_document
    JOIN kg_entity se on se.id_name = kgr.source_node
    JOIN kg_entity te on te.id_name = kgr.target_node
    """
    )

    # Execute the views using the session
    db_session.execute(allowed_docs_view)
    db_session.execute(kg_relationships_view)

    # Allow the read-only DB user to query the relationships view
    grant_kg_relationships = text(
        f"GRANT SELECT ON {kg_relationships_view_name} TO {DB_READONLY_USER}"
    )
    db_session.execute(grant_kg_relationships)

    db_session.commit()
    return None
def drop_views(
    db_session: Session,
    allowed_docs_view_name: str = "allowed_docs",
    kg_relationships_view_name: str = "kg_relationships_with_access",
) -> None:
    """
    Drops the temporary views created by create_views.

    Args:
        db_session: SQLAlchemy session
        allowed_docs_view_name: Name of the allowed_docs view
        kg_relationships_view_name: Name of the kg_relationships view
    """
    # First revoke access from the readonly user
    db_session.execute(
        text(f"REVOKE SELECT ON {kg_relationships_view_name} FROM {DB_READONLY_USER}")
    )

    # Drop in reverse order of creation so the dependent view goes first
    for view_name in (kg_relationships_view_name, allowed_docs_view_name):
        db_session.execute(text(f"DROP VIEW IF EXISTS {view_name}"))

    db_session.commit()
    return None

View File

@@ -39,6 +39,7 @@ from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy import PrimaryKeyConstraint
from onyx.auth.schemas import UserRole
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
@@ -69,6 +70,7 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.context.search.enums import RecencyBiasSetting
from onyx.kg.models import KGStage
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.headers import HeaderItemDict
@@ -586,6 +588,16 @@ class Document(Base):
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
# tables for the knowledge graph data
kg_stage: Mapped[KGStage] = mapped_column(
Enum(KGStage, native_enum=False),
comment="Status of knowledge graph extraction for this document",
)
kg_processing_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
@@ -604,12 +616,589 @@ class Document(Base):
)
class KGConfig(Base):
    """Key/value-list store for knowledge-graph configuration variables."""

    __tablename__ = "kg_config"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Variable name, e.g. "KG_ENABLED" or "KG_VENDOR" (see KGConfigVars)
    kg_variable_name: Mapped[str] = mapped_column(NullFilteredString, nullable=False)
    # Values stored as a string array even for single-valued settings
    kg_variable_values: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
class KGEntityType(Base):
    """Definition of an entity type in the knowledge graph (e.g. ACCOUNT)."""

    __tablename__ = "kg_entity_type"

    # Primary identifier
    id_name: Mapped[str] = mapped_column(
        String, primary_key=True, nullable=False, index=True
    )

    description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)

    # Grounding kind for this entity type (stored as a string)
    grounding: Mapped[str] = mapped_column(
        NullFilteredString, nullable=False, index=False
    )

    # NOTE(review): annotated as Mapped[str] but the column is JSONB with a
    # dict default — the annotation looks wrong; confirm it should be dict.
    attributes: Mapped[str] = mapped_column(
        postgresql.JSONB,
        nullable=True,
        default=dict,
        server_default="{}",
        comment="Filtering based on document attribute",
    )

    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Whether this entity type participates in KG processing
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    deep_extraction: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Name of the grounding source connected to this entity type
    grounded_source_name: Mapped[str] = mapped_column(
        NullFilteredString, nullable=False, index=False
    )

    # NOTE(review): nullable=True but the annotation is non-optional —
    # confirm whether this should be Mapped[list[str] | None].
    entity_values: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=True, default=None
    )

    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this entity type",
    )
class KGRelationshipType(Base):
    """A (source entity type, relationship, target entity type) edge definition
    in the normalized knowledge graph."""

    __tablename__ = "kg_relationship_type"

    # Primary identifier; built as "<SOURCE_TYPE>__<relationship>__<TARGET_TYPE>"
    # (see add_relationship_type in the KG relationship DB module)
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        nullable=False,
        index=True,
    )
    # The relationship name component on its own
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    source_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    definition: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this relationship type represents a definition",
    )
    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this relationship type",
    )
    # Mirrors the relationship name (set to the same value as `name` on insert)
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    # Active by default (unlike KGEntityType, which defaults to inactive)
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to EntityType
    source_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[source_entity_type_id_name],
        backref="source_relationship_type",
    )
    target_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[target_entity_type_id_name],
        backref="target_relationship_type",
    )
class KGRelationshipTypeExtractionStaging(Base):
    """Staging-table twin of KGRelationshipType, holding relationship types
    produced by extraction before they are normalized."""

    __tablename__ = "kg_relationship_type_extraction_staging"

    # Primary identifier; same "<SOURCE>__<rel>__<TARGET>" format as the
    # normalized table
    id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    # NOTE: both staging and normalized relationship types point at the
    # (non-staging) kg_entity_type table
    source_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    definition: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this relationship type represents a definition",
    )
    clustering: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Clustering information for this relationship type",
    )
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to EntityType
    source_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[source_entity_type_id_name],
        backref="source_relationship_type_staging",
    )
    target_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType",
        foreign_keys=[target_entity_type_id_name],
        backref="target_relationship_type_staging",
    )
class KGEntity(Base):
    """A normalized knowledge-graph entity (an instance of a KGEntityType)."""

    __tablename__ = "kg_entity"

    # Primary identifier; entity ids follow "<TYPE>:<value>" (the type prefix
    # is split off via id_name.split(":")[0] elsewhere in this codebase)
    id_name: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, index=True
    )

    # Basic entity information
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    attributes: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Attributes for this entity",
    )
    # Originating document, when the entity was extracted from one
    document_id: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    alternative_names: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Reference to KGEntityType
    entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )

    # Relationship to KGEntityType
    entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")

    description: Mapped[str | None] = mapped_column(String, nullable=True)
    keywords: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Access control
    acl: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Boosts - using JSON for flexibility
    boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)

    event_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="Time of the event being processed",
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    __table_args__ = (
        # Fixed column names in indexes
        # NOTE(review): these index names are reused verbatim by
        # KGEntityExtractionStaging; Postgres requires index names to be
        # unique per schema, so metadata.create_all would collide -- verify
        # the Alembic migrations assign distinct names.
        Index("ix_entity_type_acl", entity_type_id_name, acl),
        Index("ix_entity_name_search", name, entity_type_id_name),
    )
class KGEntityExtractionStaging(Base):
    """Staging-table twin of KGEntity, holding entities produced by extraction
    before they are normalized."""

    __tablename__ = "kg_entity_extraction_staging"

    # Primary identifier in "<TYPE>:<value>" format
    id_name: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, index=True
    )

    # Basic entity information
    name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
    attributes: Mapped[dict] = mapped_column(
        postgresql.JSONB,
        nullable=False,
        default=dict,
        server_default="{}",
        comment="Attributes for this entity",
    )
    document_id: Mapped[str | None] = mapped_column(
        NullFilteredString, nullable=True, index=True
    )
    alternative_names: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Reference to KGEntityType (the non-staging type table)
    entity_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )

    # Relationship to KGEntityType
    entity_type: Mapped["KGEntityType"] = relationship(
        "KGEntityType", backref="entity_staging"
    )

    description: Mapped[str | None] = mapped_column(String, nullable=True)
    keywords: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )
    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Access control
    acl: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Boosts - using JSON for flexibility
    boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)

    event_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        comment="Time of the event being processed",
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    __table_args__ = (
        # Fixed column names in indexes
        # NOTE(review): identical index names also appear on KGEntity; index
        # names must be unique per Postgres schema -- confirm the migrations
        # use distinct names for this table.
        Index("ix_entity_type_acl", entity_type_id_name, acl),
        Index("ix_entity_name_search", name, entity_type_id_name),
    )
class KGRelationship(Base):
    """A normalized knowledge-graph edge between two KGEntity rows, keyed by
    (id_name, source_document)."""

    __tablename__ = "kg_relationship"

    # Primary identifier - now part of composite key (see __table_args__)
    id_name: Mapped[str] = mapped_column(NullFilteredString, index=True)
    # Document the relationship was extracted from.
    # NOTE(review): declared nullable but also part of the composite primary
    # key below; Postgres forces PK columns to NOT NULL -- confirm intended.
    source_document: Mapped[str | None] = mapped_column(
        NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
    )

    # Source and target nodes (foreign keys to Entity table)
    source_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )
    target_node: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
    )
    # Denormalized entity-type ids of the two endpoints
    source_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )

    # Relationship type (lower-cased relationship string)
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)

    # Add new relationship type reference
    relationship_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_relationship_type.id_name"),
        nullable=False,
        index=True,
    )

    # Add the SQLAlchemy relationship property
    relationship_type: Mapped["KGRelationshipType"] = relationship(
        "KGRelationshipType", backref="relationship"
    )

    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to Entity table
    source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
    target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
    document: Mapped["Document"] = relationship(
        "Document", foreign_keys=[source_document]
    )

    __table_args__ = (
        # Composite primary key
        PrimaryKeyConstraint("id_name", "source_document"),
        # Index for querying relationships by type
        Index("ix_kg_relationship_type", type),
        # Composite index for source/target queries
        Index("ix_kg_relationship_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )
class KGRelationshipExtractionStaging(Base):
    """Staging-table twin of KGRelationship; its node/relationship-type
    foreign keys point at the staging tables instead."""

    __tablename__ = "kg_relationship_extraction_staging"

    # Primary identifier - now part of composite key (see __table_args__)
    id_name: Mapped[str] = mapped_column(NullFilteredString, index=True)
    # NOTE(review): nullable but part of the composite PK below -- Postgres
    # forces PK columns to NOT NULL; confirm intended.
    source_document: Mapped[str | None] = mapped_column(
        NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
    )

    # Source and target nodes (foreign keys to the staging Entity table)
    source_node: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )
    target_node: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )
    source_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )
    target_node_type: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_entity_type.id_name"),
        nullable=False,
        index=True,
    )

    # Relationship type
    type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)

    # Add new relationship type reference (staging type table)
    relationship_type_id_name: Mapped[str] = mapped_column(
        NullFilteredString,
        ForeignKey("kg_relationship_type_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )

    # Add the SQLAlchemy relationship property
    relationship_type: Mapped["KGRelationshipTypeExtractionStaging"] = relationship(
        "KGRelationshipTypeExtractionStaging", backref="relationship_staging"
    )

    occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to Entity table
    source: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[source_node]
    )
    target: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[target_node]
    )
    document: Mapped["Document"] = relationship(
        "Document", foreign_keys=[source_document]
    )

    __table_args__ = (
        # Composite primary key
        PrimaryKeyConstraint("id_name", "source_document"),
        # NOTE(review): these index and unique-constraint names are identical
        # to the ones on KGRelationship; names must be unique per Postgres
        # schema -- confirm the migrations use distinct names here.
        # Index for querying relationships by type
        Index("ix_kg_relationship_type", type),
        # Composite index for source/target queries
        Index("ix_kg_relationship_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )
class KGTerm(Base):
    """A search term mapped to the entity types it may refer to."""

    __tablename__ = "kg_term"

    # Make id_term the primary key (the term string itself)
    id_term: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, nullable=False, index=True
    )

    # List of entity types this term applies to
    entity_types: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    __table_args__ = (
        # Index for searching terms with specific entity types
        Index("ix_search_term_entities", entity_types),
        # Index for term lookups (redundant with the PK index, but kept as-is)
        Index("ix_search_term_term", id_term),
    )
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
@@ -691,6 +1280,14 @@ class Connector(Base):
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
kg_processing_enabled: Mapped[bool] = mapped_column(
Boolean,
nullable=False,
default=False,
comment="Whether this connector should extract knowledge graph entities",
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(

View File

@@ -0,0 +1,523 @@
from typing import cast
from typing import List
from typing import Union
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipType
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.db.models import KGStage
from onyx.kg.utils.formatting_utils import format_entity
from onyx.kg.utils.formatting_utils import format_relationship
from onyx.kg.utils.formatting_utils import generate_relationship_type
def add_relationship(
    db_session: Session,
    kg_stage: KGStage,
    relationship_id_name: str,
    source_document_id: str,
    occurrences: int | None = None,
) -> Union["KGRelationship", "KGRelationshipExtractionStaging"]:
    """
    Upsert a relationship between two entities.

    Args:
        db_session: SQLAlchemy database session
        kg_stage: KG pipeline stage; selects the staging vs. normalized table
        relationship_id_name: Identifier in "source__relationship__target" form
        source_document_id: ID of the source document
        occurrences: Optional count of similar relationships clustered together
            (defaults to 1)

    Returns:
        The created or updated relationship row

    Raises:
        ValueError: If kg_stage is invalid or the row cannot be re-fetched
        sqlalchemy.exc.IntegrityError: If the referenced entities don't exist
    """
    # Split "source__relationship__target" into its three components
    (
        source_entity_id_name,
        relationship_string,
        target_entity_id_name,
    ) = relationship_id_name.split("__")

    source_entity_id_name = format_entity(source_entity_id_name)
    source_entity_type = source_entity_id_name.split(":")[0]
    target_entity_id_name = format_entity(target_entity_id_name)
    target_entity_type = target_entity_id_name.split(":")[0]
    relationship_id_name = format_relationship(relationship_id_name)
    relationship_type = generate_relationship_type(relationship_id_name)

    new_occurrences = occurrences or 1
    relationship_data = {
        "id_name": relationship_id_name,
        "source_node": source_entity_id_name,
        "target_node": target_entity_id_name,
        "source_node_type": source_entity_type,
        "target_node_type": target_entity_type,
        "type": relationship_string.lower(),
        "relationship_type_id_name": relationship_type,
        "source_document": source_document_id,
        "occurrences": new_occurrences,
    }

    # Pick the target table once; it drives the insert, the conflict target,
    # and the re-fetch below.
    _KGTable: type[KGRelationship] | type[KGRelationshipExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGTable = KGRelationshipExtractionStaging
        conflict_constraint = "kg_relationship_extraction_staging_pkey"
    elif kg_stage == KGStage.NORMALIZED:
        _KGTable = KGRelationship
        conflict_constraint = "kg_relationship_pkey"
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")

    # Upsert: on conflict, increment the EXISTING row's occurrence count.
    # (The previous version computed the sum in Python from the *new* values,
    # yielding 2 * new_occurrences and discarding the stored count.)
    stmt = (
        postgresql.insert(_KGTable)
        .values(**relationship_data)
        .on_conflict_do_update(
            constraint=conflict_constraint,
            set_={
                "occurrences": _KGTable.occurrences + new_occurrences,
                "time_updated": func.now(),
            },
        )
    )
    db_session.execute(stmt)
    db_session.flush()  # Flush to get any DB errors early

    # Fetch the updated/inserted record
    result: Union[KGRelationship, KGRelationshipExtractionStaging, None] = (
        db_session.query(_KGTable)
        .filter_by(id_name=relationship_id_name, source_document=source_document_id)
        .first()
    )

    if result is None:
        raise ValueError(
            f"Failed to create or update relationship with id_name: {relationship_id_name}"
        )

    return result
def add_or_increment_relationship(
    db_session: Session,
    kg_stage: KGStage,
    relationship_id_name: str,
    source_document_id: str,
    new_occurrences: int = 1,
) -> KGRelationship | KGRelationshipExtractionStaging:
    """
    Add a relationship between two entities if it doesn't exist, or increment
    its occurrence count if it does.

    Args:
        db_session: SQLAlchemy database session
        kg_stage: KG pipeline stage; selects the staging vs. normalized table
        relationship_id_name: The ID name in format "source__relationship__target"
        source_document_id: ID of the source document
        new_occurrences: Amount to increment (or seed) the occurrence count by

    Returns:
        The created or updated relationship row

    Raises:
        ValueError: If kg_stage is invalid
        sqlalchemy.exc.IntegrityError: If there's an error with the database operation
    """
    # Format the relationship_id_name
    relationship_id_name = format_relationship(relationship_id_name)

    _KGTable: type[KGRelationship] | type[KGRelationshipExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGTable = KGRelationshipExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGTable = KGRelationship
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")

    # Check if the relationship already exists
    existing_relationship = (
        db_session.query(_KGTable)
        .filter(_KGTable.id_name == relationship_id_name)
        .filter(_KGTable.source_document == source_document_id)
        .first()
    )

    if existing_relationship:
        # Already present: bump the count in place. (The `cast` the original
        # applied here was a no-op; the query already returns the right type.)
        existing_relationship.occurrences = (
            existing_relationship.occurrences or 0
        ) + new_occurrences
        db_session.flush()
        return existing_relationship

    # Not present: delegate the insert (add_relationship flushes itself, so
    # the extra flush the original did beforehand was redundant).
    return add_relationship(
        db_session,
        kg_stage,
        relationship_id_name,
        source_document_id,
        occurrences=new_occurrences,
    )
def add_relationship_type(
    db_session: Session,
    kg_stage: KGStage,
    source_entity_type: str,
    relationship_type: str,
    target_entity_type: str,
    definition: bool = False,
    extraction_count: int = 0,
) -> str:
    """
    Upsert a relationship type.

    Args:
        db_session: SQLAlchemy session
        kg_stage: KG pipeline stage; selects the staging vs. normalized table
        source_entity_type: Type of the source entity
        relationship_type: Type of relationship
        target_entity_type: Type of the target entity
        definition: Whether this relationship type represents a definition
        extraction_count: Occurrences to record (added to an existing row's count)

    Returns:
        The id_name of the created/updated relationship type
        (note: a string, not the ORM object)

    Raises:
        ValueError: If kg_stage is invalid
    """
    id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"

    relationship_data = {
        "id_name": id_name,
        "name": relationship_type,
        "source_entity_type_id_name": source_entity_type.upper(),
        "target_entity_type_id_name": target_entity_type.upper(),
        "definition": definition,
        "occurrences": extraction_count,
        "type": relationship_type,  # Using the relationship_type as the type
        "active": True,  # Setting as active by default
    }

    _KGTable: type[KGRelationshipType] | type[KGRelationshipTypeExtractionStaging]
    if kg_stage == KGStage.EXTRACTED:
        _KGTable = KGRelationshipTypeExtractionStaging
    elif kg_stage == KGStage.NORMALIZED:
        _KGTable = KGRelationshipType
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")

    # Upsert: on conflict, refresh the descriptive fields and increment the
    # EXISTING row's occurrence count. (The previous version summed the new
    # values in Python -- 2 * extraction_count -- and discarded the stored count.)
    stmt = (
        postgresql.insert(_KGTable)
        .values(**relationship_data)
        .on_conflict_do_update(
            index_elements=["id_name"],
            set_={
                "name": relationship_data["name"],
                "source_entity_type_id_name": relationship_data[
                    "source_entity_type_id_name"
                ],
                "target_entity_type_id_name": relationship_data[
                    "target_entity_type_id_name"
                ],
                "definition": relationship_data["definition"],
                "occurrences": _KGTable.occurrences + extraction_count,
                "type": relationship_data["type"],
                "active": relationship_data["active"],
                "time_updated": func.now(),
            },
        )
    )
    db_session.execute(stmt)
    db_session.flush()  # Flush to get any DB errors early

    return id_name
def get_all_relationship_types(
    db_session: Session, kg_stage: KGStage
) -> list["KGRelationshipType"] | list["KGRelationshipTypeExtractionStaging"]:
    """
    Retrieve all relationship types for the given KG stage.

    Args:
        db_session: SQLAlchemy database session
        kg_stage: KG pipeline stage; selects the staging vs. normalized table.
            (Annotation fixed from `str` -- the comparisons below are against
            KGStage members, matching every sibling function in this module.)

    Returns:
        List of KGRelationshipType or KGRelationshipTypeExtractionStaging objects

    Raises:
        ValueError: If kg_stage is invalid
    """
    if kg_stage == KGStage.EXTRACTED:
        return db_session.query(KGRelationshipTypeExtractionStaging).all()
    elif kg_stage == KGStage.NORMALIZED:
        return db_session.query(KGRelationshipType).all()
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")
def get_all_relationships(
    db_session: Session, kg_stage: KGStage
) -> list["KGRelationship"] | list["KGRelationshipExtractionStaging"]:
    """
    Fetch every relationship row for the given KG stage.

    Args:
        db_session: SQLAlchemy database session
        kg_stage: KG pipeline stage; selects the staging vs. normalized table

    Returns:
        All KGRelationship (normalized) or KGRelationshipExtractionStaging
        (staging) rows

    Raises:
        ValueError: If kg_stage is not a recognized stage
    """
    if kg_stage == KGStage.NORMALIZED:
        return db_session.query(KGRelationship).all()
    if kg_stage == KGStage.EXTRACTED:
        return db_session.query(KGRelationshipExtractionStaging).all()
    raise ValueError(f"Invalid kg_stage: {kg_stage}")
def delete_relationships_by_id_names(
    db_session: Session, id_names: list[str], kg_stage: KGStage
) -> int:
    """
    Delete relationships from the database based on a list of id_names.

    Args:
        db_session: SQLAlchemy database session
        id_names: List of relationship id_names to delete
        kg_stage: KG pipeline stage; selects the staging vs. normalized table

    Returns:
        Number of relationships deleted

    Raises:
        ValueError: If kg_stage is invalid (the original silently returned 0,
            unlike every other function in this module)
        sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
    """
    if kg_stage == KGStage.EXTRACTED:
        deleted_count = (
            db_session.query(KGRelationshipExtractionStaging)
            .filter(KGRelationshipExtractionStaging.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    elif kg_stage == KGStage.NORMALIZED:
        deleted_count = (
            db_session.query(KGRelationship)
            .filter(KGRelationship.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")

    db_session.flush()  # Flush to ensure deletion is processed

    return deleted_count
def delete_relationship_types_by_id_names(
    db_session: Session, id_names: list[str], kg_stage: KGStage
) -> int:
    """
    Delete relationship types from the database based on a list of id_names.

    Args:
        db_session: SQLAlchemy database session
        id_names: List of relationship type id_names to delete
        kg_stage: KG pipeline stage; selects the staging vs. normalized table

    Returns:
        Number of relationship types deleted

    Raises:
        ValueError: If kg_stage is invalid (the original silently returned 0,
            unlike every other function in this module)
        sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
    """
    if kg_stage == KGStage.EXTRACTED:
        deleted_count = (
            db_session.query(KGRelationshipTypeExtractionStaging)
            .filter(KGRelationshipTypeExtractionStaging.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    elif kg_stage == KGStage.NORMALIZED:
        deleted_count = (
            db_session.query(KGRelationshipType)
            .filter(KGRelationshipType.id_name.in_(id_names))
            .delete(synchronize_session=False)
        )
    else:
        raise ValueError(f"Invalid kg_stage: {kg_stage}")

    db_session.flush()  # Flush to ensure deletion is processed

    return deleted_count
def get_relationships_for_entity_type_pairs(
    db_session: Session, entity_type_pairs: list[tuple[str, str]]
) -> list["KGRelationshipType"]:
    """
    Get relationship types from the database based on a list of entity type pairs.

    Args:
        db_session: SQLAlchemy database session
        entity_type_pairs: List of (source_entity_type, target_entity_type) tuples

    Returns:
        List of KGRelationshipType objects where source and target types match
        the provided pairs; empty list when no pairs are given
    """
    # Guard: or_() with zero clauses is deprecated in SQLAlchemy and would
    # otherwise produce surprising SQL -- no pairs means no matches.
    if not entity_type_pairs:
        return []

    conditions = [
        (
            (KGRelationshipType.source_entity_type_id_name == source_type)
            & (KGRelationshipType.target_entity_type_id_name == target_type)
        )
        for source_type, target_type in entity_type_pairs
    ]

    return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
def get_allowed_relationship_type_pairs(
    db_session: Session, entities: list[str]
) -> list[str]:
    """
    Get the allowed relationship-type id_names for the given entities.

    Args:
        db_session: SQLAlchemy database session
        entities: Entity id_names ("TYPE:value"); the type prefix is extracted

    Returns:
        List of id_names from KGRelationshipType where both the source and the
        target entity types occur among the given entities' types
    """
    # Deduplicate the "TYPE" prefixes of the provided entity ids
    entity_types = list({entity.split(":")[0] for entity in entities})

    allowed_query = (
        db_session.query(KGRelationshipType.id_name)
        .filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
        .filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
        .distinct()
    )
    return [id_name for (id_name,) in allowed_query.all()]
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
    """Get all relationship ID names where the given entity is either the source or target node.

    Args:
        db_session: SQLAlchemy session
        entity_id: ID of the entity to find relationships for

    Returns:
        List of relationship ID names where the entity is either source or target
    """
    endpoint_match = or_(
        KGRelationship.source_node == entity_id,
        KGRelationship.target_node == entity_id,
    )
    rows = db_session.query(KGRelationship.id_name).filter(endpoint_match).all()
    return [id_name for (id_name,) in rows]
def get_relationship_types_of_entity_types(
    db_session: Session, entity_types_id: str
) -> List[str]:
    """Get all relationship-type ID names where the given ENTITY TYPE is either
    the source or target entity type.

    (Docstring fixed: the original was copy-pasted from the entity-level
    function and described entity relationships, not relationship types.)

    Args:
        db_session: SQLAlchemy session
        entity_types_id: ID of the entity type; an optional ":*" wildcard
            suffix is stripped before matching

    Returns:
        List of relationship-type ID names referencing the entity type
    """
    # Accept wildcard-suffixed ids such as "ACCOUNT:*"
    entity_types_id = entity_types_id.removesuffix(":*")

    return [
        row[0]
        for row in (
            db_session.query(KGRelationshipType.id_name)
            .filter(
                or_(
                    KGRelationshipType.source_entity_type_id_name == entity_types_id,
                    KGRelationshipType.target_entity_type_id_name == entity_types_id,
                )
            )
            .all()
        )
    ]
def delete_document_references_from_kg(db_session: Session, document_id: str) -> None:
    """Remove every KG row (relationships and entities, in both the normalized
    and extraction-staging tables) that references the given document.

    Flushes but does not commit; the caller owns the transaction.

    Args:
        db_session: SQLAlchemy database session
        document_id: ID of the document whose KG references are purged
    """
    # Delete relationships from normalized stage
    db_session.query(KGRelationship).filter(
        KGRelationship.source_document == document_id
    ).delete(synchronize_session=False)

    # Delete relationships from extraction staging
    db_session.query(KGRelationshipExtractionStaging).filter(
        KGRelationshipExtractionStaging.source_document == document_id
    ).delete(synchronize_session=False)

    # Delete entities from normalized stage
    db_session.query(KGEntity).filter(KGEntity.document_id == document_id).delete(
        synchronize_session=False
    )

    # Delete entities from extraction staging
    db_session.query(KGEntityExtractionStaging).filter(
        KGEntityExtractionStaging.document_id == document_id
    ).delete(synchronize_session=False)

    db_session.flush()
def delete_from_kg_relationships_extraction_staging__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete relationships from the extraction staging table.

    Bulk-deletes every staging relationship whose source document is in
    ``document_ids``; the caller is responsible for committing.
    """
    staged = db_session.query(KGRelationshipExtractionStaging).filter(
        KGRelationshipExtractionStaging.source_document.in_(document_ids)
    )
    staged.delete(synchronize_session=False)
def delete_from_kg_relationships__no_commit(
    db_session: Session, document_ids: list[str]
) -> None:
    """Delete relationships from the normalized table.

    Bulk-deletes every normalized relationship whose source document is in
    ``document_ids``; the caller is responsible for committing.
    """
    normalized = db_session.query(KGRelationship).filter(
        KGRelationship.source_document.in_(document_ids)
    )
    normalized.delete(synchronize_session=False)

View File

@@ -91,6 +91,25 @@ schema {{ schema_name }} {
indexing: attribute
}
# Separate array fields for knowledge graph data
field kg_entities type weightedset<string> {
rank: filter
indexing: summary | attribute
attribute: fast-search
}
field kg_relationships type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
field kg_terms type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute

View File

@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
# build the list of fields to retrieve
field_set_list = (
None
if not field_names
else [f"{index_name}:{field_name}" for field_name in field_names]
[f"{field_name}" for field_name in field_names] if field_names else []
)
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
if (
field_set_list
and filters.access_control_list
and acl_fieldset_entry not in field_set_list
):
field_set_list.append(acl_fieldset_entry)
field_set = ",".join(field_set_list) if field_set_list else None
if field_set_list:
field_set = f"{index_name}:" + ",".join(field_set_list)
else:
field_set = None
# build filters
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"

View File

@@ -18,6 +18,7 @@ from uuid import UUID
import httpx # type: ignore
import jinja2
import requests # type: ignore
from pydantic import BaseModel
from retry import retry
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
@@ -30,6 +31,7 @@ from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
@@ -86,6 +88,17 @@ httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
def update_kg_type_dict(
    dict_to_update: dict[str, dict], kg_type: str, value_set: set[str]
) -> dict[str, dict]:
    """Record a weighted-set assignment (weight 1 per item) for a KG field.

    Ensures a "fields" sub-dict exists in ``dict_to_update``, then sets
    ``fields[kg_type]`` to a Vespa-style ``{"assign": {item: 1, ...}}``
    mapping built from ``value_set``. Mutates the dict in place and also
    returns it for convenience.
    """
    fields = dict_to_update.setdefault("fields", {})
    fields[kg_type] = {"assign": dict.fromkeys(value_set, 1)}
    return dict_to_update
@dataclass
class _VespaUpdateRequest:
document_id: str
@@ -93,6 +106,39 @@ class _VespaUpdateRequest:
update_request: dict[str, dict]
class KGVespaChunkUpdateRequest(BaseModel):
    """A fully-prepared Vespa update for a single chunk's KG fields."""

    document_id: str
    chunk_id: int
    # Vespa document endpoint URL the PUT request is sent to
    url: str
    # JSON body of the Vespa update (e.g. {"fields": {...}})
    update_request: dict[str, dict]
class KGUChunkUpdateRequest(BaseModel):
    """
    Update KG fields for a single chunk of a document.
    """

    document_id: str
    chunk_id: int
    # Primary entity the chunk is about
    core_entity: str
    # Optional KG payloads; None means "no update for this field"
    entities: set[str] | None = None
    relationships: set[str] | None = None
    terms: set[str] | None = None
    converted_attributes: set[str] | None = None
    attributes: dict[str, str | list[str]] | None = None
class KGUDocumentUpdateRequest(BaseModel):
    """
    Update KG fields for a whole document (all fields required, unlike the
    chunk-level request).
    """

    document_id: str
    entities: set[str]
    relationships: set[str]
    terms: set[str]
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
@@ -501,6 +547,51 @@ class VespaIndex(DocumentIndex):
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
raise requests.HTTPError(failure_msg) from e
@classmethod
def _apply_kg_chunk_updates_batched(
    cls,
    updates: list[KGVespaChunkUpdateRequest],
    httpx_client: httpx.Client,
    batch_size: int = BATCH_SIZE,
) -> None:
    """Runs a batch of KG chunk updates in parallel via the ThreadPoolExecutor.

    Args:
        updates: Pre-built per-chunk update requests (URL + JSON body).
        httpx_client: Client (used as a context manager, so it is closed here).
        batch_size: Number of updates submitted per executor batch.

    Raises:
        requests.HTTPError: If any individual chunk update fails, with the
            failing document id in the message.
    """

    def _kg_update_chunk(
        update: KGVespaChunkUpdateRequest, http_client: httpx.Client
    ) -> httpx.Response:
        return http_client.put(
            update.url,
            headers={"Content-Type": "application/json"},
            json=update.update_request,
        )

    # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
    # indexing / updates / deletes since we have to make a large volume of requests.
    with (
        concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
        httpx_client as http_client,
    ):
        for update_batch in batch_generator(updates, batch_size):
            future_to_document_id = {
                executor.submit(
                    _kg_update_chunk,
                    update,
                    http_client,
                ): update.document_id
                for update in update_batch
            }
            for future in concurrent.futures.as_completed(future_to_document_id):
                res = future.result()
                try:
                    res.raise_for_status()
                except httpx.HTTPStatusError as e:
                    # BUGFIX: `res` is an httpx.Response, so raise_for_status()
                    # raises httpx.HTTPStatusError — the previous
                    # `except requests.HTTPError` branch was unreachable and the
                    # failing document id was never attached. Re-raise as
                    # requests.HTTPError to keep the exception type callers
                    # already handle (matching _apply_updates_batched).
                    failure_msg = f"Failed to update document: {future_to_document_id[future]}"
                    raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
@@ -584,6 +675,77 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def kg_chunk_updates(
    self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
) -> None:
    """Translate per-chunk KG update requests into Vespa PUTs and apply them.

    Entities implied by relationship endpoints ("src__rel__tgt") are added to
    the explicit entity set before the assignment payload is built.
    """
    logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
    update_start = time.monotonic()

    vespa_requests: list[KGVespaChunkUpdateRequest] = []
    for request in kg_update_requests:
        update_dict: dict[str, dict] = {"fields": {}}

        # Collect entities that appear as relationship endpoints.
        implied_entities: set[str] = set()
        if request.relationships is not None:
            for relationship in request.relationships:
                parts = relationship.split("__")
                if len(parts) == 3:
                    implied_entities.update((parts[0], parts[2]))
            update_dict = update_kg_type_dict(
                update_dict, "kg_relationships", request.relationships
            )

        if request.entities is not None or implied_entities:
            if request.entities is None:
                kg_entities = implied_entities
            else:
                kg_entities = set(request.entities) | implied_entities
            update_dict = update_kg_type_dict(
                update_dict, "kg_entities", kg_entities
            )

        if request.terms is not None:
            update_dict = update_kg_type_dict(
                update_dict, "kg_terms", request.terms
            )

        # Nothing to assign -> skip this request entirely.
        if not update_dict["fields"]:
            logger.error("Update request received but nothing to update")
            continue

        doc_chunk_id = get_uuid_from_chunk_info(
            document_id=request.document_id,
            chunk_id=request.chunk_id,
            tenant_id=tenant_id,
            large_chunk_id=None,
        )
        vespa_requests.append(
            KGVespaChunkUpdateRequest(
                document_id=request.document_id,
                chunk_id=request.chunk_id,
                url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
                update_request=update_dict,
            )
        )

    with self.httpx_client_context as httpx_client:
        self._apply_kg_chunk_updates_batched(vespa_requests, httpx_client)

    logger.debug(
        "Finished updating Vespa documents in %.2f seconds",
        time.monotonic() - update_start,
    )
@retry(
tries=3,
delay=1,

View File

@@ -0,0 +1,80 @@
from pydantic import BaseModel
from retry import retry
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.utils.logger import setup_logger
# from backend.onyx.chat.process_message import get_inference_chunks
# from backend.onyx.document_index.vespa.index import VespaIndex
logger = setup_logger()
class KGChunkInfo(BaseModel):
    """KG fields stored on a single Vespa chunk; each is a weighted set
    rendered as value -> weight."""

    kg_relationships: dict[str, int]
    kg_entities: dict[str, int]
    kg_terms: dict[str, int]
@retry(tries=3, delay=1, backoff=2)
def get_document_kg_info(
    document_id: str,
    index_name: str,
    filters: IndexFilters | None = None,
) -> dict[int, KGChunkInfo]:
    """
    Retrieve the KG fields of every chunk of a Vespa document.

    Args:
        document_id: The unique identifier of the document.
        index_name: The name of the Vespa index to query.
        filters: Optional access control filters to apply.

    Returns:
        Mapping of chunk position (enumeration order of the visit API
        response) to its KGChunkInfo. Empty dict if the document has no
        chunks. (The previous `dict | None` annotation was wrong — this
        function never returns None.)
    """
    kg_doc_info: dict[int, KGChunkInfo] = {}

    # Use the existing visit API infrastructure; only fetch the KG fields.
    document_chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=filters or IndexFilters(access_control_list=None),
        field_names=["kg_relationships", "kg_entities", "kg_terms"],
        get_large_chunks=False,
    )

    for chunk_id, document_chunk in enumerate(document_chunks):
        kg_chunk_info = KGChunkInfo(
            kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
            kg_entities=document_chunk["fields"].get("kg_entities", {}),
            kg_terms=document_chunk["fields"].get("kg_terms", {}),
        )
        # NOTE(review): keys are enumeration positions, not the chunks' own
        # chunk_id field — confirm the visit API returns chunks in order.
        kg_doc_info[chunk_id] = kg_chunk_info

    return kg_doc_info
@retry(tries=3, delay=1, backoff=2)
def update_kg_chunks_vespa_info(
    kg_update_requests: list[KGUChunkUpdateRequest],
    index_name: str,
    tenant_id: str,
) -> None:
    """
    Push per-chunk KG updates (entities / relationships / terms) to Vespa.

    Args:
        kg_update_requests: One update request per (document, chunk) pair.
        index_name: Name of the Vespa index holding the chunks.
        tenant_id: Tenant whose chunk UUIDs should be targeted.
    """
    # A throwaway single-index VespaIndex instance is sufficient for issuing
    # chunk updates; no secondary index or large-chunk handling is needed.
    vespa_index = VespaIndex(
        index_name=index_name,
        secondary_index_name=None,
        large_chunks_enabled=False,
        secondary_large_chunks_enabled=False,
        multitenant=False,
        httpx_client=None,
    )

    vespa_index.kg_chunk_updates(
        kg_update_requests=kg_update_requests, tenant_id=tenant_id
    )

View File

@@ -54,6 +54,47 @@ def build_vespa_filters(
return result
def _build_kg_filter(
kg_entities: list[str] | None,
kg_relationships: list[str] | None,
kg_terms: list[str] | None,
) -> str:
if not kg_entities and not kg_relationships and not kg_terms:
return ""
filter_parts = []
# Process each filter type using the same pattern
for filter_type, values in [
("kg_entities", kg_entities),
("kg_relationships", kg_relationships),
("kg_terms", kg_terms),
]:
if values:
filter_parts.append(
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
)
return f"({' and '.join(filter_parts)}) and "
def _build_kg_source_filters(
    kg_sources: list[str] | None,
) -> str:
    """Build a disjunctive document-id filter for the given KG sources.

    Returns "" when no sources are given; otherwise an or-joined group of
    `contains` clauses with a trailing " and " for chaining.
    """
    if not kg_sources:
        return ""
    clauses = (f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources)
    return f"({' or '.join(clauses)}) and "
def _build_kg_chunk_id_zero_only_filter(
kg_chunk_id_zero_only: bool,
) -> str:
if not kg_chunk_id_zero_only:
return ""
return "(chunk_id = 0 ) and "
def _build_time_filter(
cutoff: datetime | None,
untimed_doc_cutoff: timedelta = timedelta(days=92),
@@ -106,6 +147,19 @@ def build_vespa_filters(
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
# Knowledge Graph Filters
filter_str += _build_kg_filter(
kg_entities=filters.kg_entities,
kg_relationships=filters.kg_relationships,
kg_terms=filters.kg_terms,
)
filter_str += _build_kg_source_filters(filters.kg_sources)
filter_str += _build_kg_chunk_id_zero_only_filter(
filters.kg_chunk_id_zero_only or False
)
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]

View File

@@ -0,0 +1,178 @@
from typing import Set
from onyx.db.document import update_document_kg_info
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import add_entity
from onyx.db.entities import delete_entities_by_id_names
from onyx.db.entities import get_entities_by_grounding
from onyx.db.relationships import add_relationship
from onyx.db.relationships import add_relationship_type
from onyx.db.relationships import delete_relationship_types_by_id_names
from onyx.db.relationships import delete_relationships_by_id_names
from onyx.db.relationships import get_all_relationship_types
from onyx.db.relationships import get_all_relationships
from onyx.kg.models import KGGroundingType
from onyx.kg.models import KGStage
from onyx.utils.logger import setup_logger
logger = setup_logger()
def kg_incremental_cluster_updates(
    tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
) -> None:
    """
    Here we will cluster the extractions based on their cluster frameworks.
    Initially, this will only focus on grounded entities with pre-determined
    relationships, so clustering is actually not yet required.

    The primary purpose of this function is to populate the actual KG tables
    from the temp_extraction tables.

    This will change with deep extraction, where grounded-sourceless entities
    can be extracted and then need to be clustered.
    """
    # NOTE(review): index_name and processing_chunk_batch_size are currently
    # unused — presumably reserved for the future clustering logic; confirm.
    logger.info(f"Starting kg clustering for tenant {tenant_id}")

    ## Retrieval

    # Documents whose relationships did / did not transfer cleanly.
    source_documents_w_successful_transfers: Set[str] = set()
    source_documents_w_failed_transfers: Set[str] = set()

    # get objects that are now in the Staging tables
    with get_session_with_current_tenant() as db_session:
        relationship_types = get_all_relationship_types(
            db_session, kg_stage=KGStage.EXTRACTED
        )
        relationships = get_all_relationships(db_session, kg_stage=KGStage.EXTRACTED)
        grounded_entities = get_entities_by_grounding(
            db_session, KGStage.EXTRACTED, KGGroundingType.GROUNDED
        )

    ## Clustering

    # TODO: we will re-implement the cluster matching logic here

    ## Database operations

    # create the clustered objects - entities
    # (one short-lived session per row so a single failure cannot poison the rest)
    transferred_entities: list[str] = []
    for grounded_entity in grounded_entities:
        with get_session_with_current_tenant() as db_session:
            added_entity = add_entity(
                db_session,
                KGStage.NORMALIZED,
                entity_type=grounded_entity.entity_type_id_name,
                name=grounded_entity.name,
                occurrences=grounded_entity.occurrences or 1,
                document_id=grounded_entity.document_id or None,
                attributes=grounded_entity.attributes or None,
            )
            db_session.commit()

            if added_entity:
                transferred_entities.append(added_entity.id_name)

    # create the clustered objects - relationship types
    transferred_relationship_types: list[str] = []
    for relationship_type in relationship_types:
        with get_session_with_current_tenant() as db_session:
            added_relationship_type_id_name = add_relationship_type(
                db_session,
                KGStage.NORMALIZED,
                source_entity_type=relationship_type.source_entity_type_id_name,
                relationship_type=relationship_type.type,
                target_entity_type=relationship_type.target_entity_type_id_name,
                extraction_count=relationship_type.occurrences or 1,
            )
            db_session.commit()

            transferred_relationship_types.append(added_relationship_type_id_name)

    # create the clustered objects - relationships, tracking which source
    # documents transferred fully vs. had at least one failure
    transferred_relationships: list[str] = []
    for relationship in relationships:
        with get_session_with_current_tenant() as db_session:
            try:
                added_relationship = add_relationship(
                    db_session,
                    KGStage.NORMALIZED,
                    relationship_id_name=relationship.id_name,
                    source_document_id=relationship.source_document or "",
                    occurrences=relationship.occurrences or 1,
                )

                if relationship.source_document:
                    source_documents_w_successful_transfers.add(
                        relationship.source_document
                    )

                db_session.commit()

                transferred_relationships.append(added_relationship.id_name)
            except Exception as e:
                if relationship.source_document:
                    source_documents_w_failed_transfers.add(
                        relationship.source_document
                    )
                logger.error(
                    f"Error transferring relationship {relationship.id_name}: {e}"
                )

    # TODO: remove the /relationship types & entities that correspond to relationships
    # source documents that failed to transfer. I.e, do a proper rollback

    # TODO: update Vespa info when clustering/changes are performed

    # delete the added objects from the staging tables
    # (best-effort: deletion failures are logged but do not abort the run)
    try:
        with get_session_with_current_tenant() as db_session:
            delete_relationships_by_id_names(
                db_session, transferred_relationships, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationships: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            delete_relationship_types_by_id_names(
                db_session, transferred_relationship_types, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationship types: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            delete_entities_by_id_names(
                db_session, transferred_entities, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting entities: {e}")

    # Update document kg info
    # with get_session_with_current_tenant() as db_session:
    #     all_kg_extracted_documents_info = get_all_kg_extracted_documents_info(
    #         db_session
    #     )

    for document_id in source_documents_w_successful_transfers:
        # Update the document kg info
        with get_session_with_current_tenant() as db_session:
            update_document_kg_info(
                db_session,
                document_id=document_id,
                kg_stage=KGStage.NORMALIZED,
            )
            db_session.commit()

View File

@@ -0,0 +1,638 @@
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import List
from typing import Set
import numpy as np
from sklearn.cluster import SpectralClustering # type: ignore
from thefuzz import fuzz # type: ignore
from onyx.db.document import update_document_kg_info
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import add_entity
from onyx.db.entities import delete_entities_by_id_names
from onyx.db.entities import get_entities_by_grounding
from onyx.db.entity_type import get_determined_grounded_entity_types
from onyx.db.relationships import add_relationship
from onyx.db.relationships import add_relationship_type
from onyx.db.relationships import delete_relationship_types_by_id_names
from onyx.db.relationships import delete_relationships_by_id_names
from onyx.db.relationships import get_all_relationship_types
from onyx.db.relationships import get_all_relationships
from onyx.kg.models import KGGroundingType
from onyx.kg.models import KGStage
from onyx.kg.utils.embeddings import encode_string_batch
from onyx.llm.factory import get_default_llms
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _create_ge_determined_entity_map() -> Dict[str, List[str]]:
    """Create a mapping of entity type ID names to their grounding determination instructions.

    Returns:
        Mapping (a defaultdict of lists, so missing types yield []) from
        entity type ID name to its configured list of entity values.
    """
    entity_value_map: Dict[str, List[str]] = defaultdict(list)

    with get_session_with_current_tenant() as db_session:
        for determined_type in get_determined_grounded_entity_types(db_session):
            # Extra safety check — only record types that actually carry values.
            if determined_type.entity_values:
                entity_value_map[determined_type.id_name] = (
                    determined_type.entity_values
                )

    return entity_value_map
def _cluster_relationships(
    relationship_data: List[dict], n_clusters: int = 3, batch_size: int = 12
) -> Dict[int, List[str]]:
    """
    Cluster relationships using their embeddings.

    Args:
        relationship_data: List of dicts with 'name' and 'cluster_count'
        n_clusters: Number of clusters to create
        batch_size: Size of batches for embedding requests

    Returns:
        Dictionary mapping cluster IDs to lists of relationship names
    """
    # TODO: This is TEMP for the pre-defined relationships.
    # if len(relationship_data) < n_clusters:
    # Degenerate case: fewer relationships than clusters — no clustering
    # possible, so each relationship becomes its own singleton cluster.
    if len(relationship_data) < n_clusters:
        logger.warning(
            "Not enough relationships to cluster. Returning each relationship as its own cluster."
        )
        return {i: [rel["name"]] for (i, rel) in enumerate(relationship_data)}

    train_data = []
    rel_names = []

    # Process relationships in batches
    for i in range(0, len(relationship_data), batch_size):
        batch = relationship_data[i : i + batch_size]
        batch_names = [
            rel["name"].replace("_", " ") for rel in batch
        ]  # better for LLM to have spaces between words

        # Get embeddings for the entire batch at once
        batch_embeddings = encode_string_batch(batch_names)

        # Add embeddings and corresponding data
        for rel, embedding in zip(batch, batch_embeddings):
            # Weight a relationship by duplicating its embedding
            # 'cluster_count' times (a count of 0 is treated as 1).
            count = int(rel["cluster_count"]) or 1
            # Add the relationship name 'count' times
            for _ in range(count):
                train_data.append(embedding)
                rel_names.append(rel["name"])

    # Convert to numpy arrays
    X = np.array(train_data)

    # Perform clustering
    # clustering = KMeans(n_clusters=n_clusters, random_state=42)
    clustering = SpectralClustering(n_clusters=n_clusters, random_state=42)
    clusters = clustering.fit_predict(X)

    # Group relationship names by cluster (de-duplicated, order-preserving)
    cluster_groups: Dict[int, List[str]] = defaultdict(list)
    for rel_name, cluster_id in zip(rel_names, clusters):
        if rel_name not in cluster_groups[cluster_id]:
            cluster_groups[cluster_id].append(rel_name)

    return dict(cluster_groups)
def _cluster_entities(
entity_data: List[dict], n_clusters: int = 3, batch_size: int = 12
) -> Dict[int, List[str]]:
"""
Cluster entities using their embeddings.
Args:
entity_data: List of dicts with 'name' and 'cluster_count'
n_clusters: Number of clusters to create
batch_size: Size of batches for embedding requests
Returns:
Dictionary mapping cluster IDs to lists of entity names
"""
if len(entity_data) < n_clusters:
logger.warning(
"Not enough entities to cluster. Returning each entity as its own cluster."
)
return {
i: [ent["name"] for ent in entity_data] for i in range(len(entity_data))
}
train_data = []
entity_names = []
# Process entities in batches
for i in range(0, len(entity_data), batch_size):
batch = entity_data[i : i + batch_size]
batch_names = [
ent["name"].replace("_", " ") for ent in batch
] # use spaces between words for LLM
# Get embeddings for the entire batch at once
batch_embeddings = encode_string_batch(batch_names)
# Add embeddings and corresponding data
for ent, embedding in zip(batch, batch_embeddings):
count = int(ent["cluster_count"]) or 1
# Add the entity name 'count' times
for _ in range(count):
entity_names.append(ent["name"])
train_data.append(embedding)
# Convert to numpy arrays
X = np.array(train_data)
# Perform clustering
# clustering = KMeans(n_clusters=n_clusters, random_state=42)
clustering = SpectralClustering(n_clusters=n_clusters, random_state=42)
clusters = clustering.fit_predict(X)
# Group entity names by cluster
cluster_groups: Dict[int, List[str]] = defaultdict(list)
for ent_name, cluster_id in zip(entity_names, clusters):
if ent_name not in cluster_groups[cluster_id]:
cluster_groups[cluster_id].append(ent_name)
return dict(cluster_groups)
def _create_relationship_type_mapping(
full_clustering_results: Dict[str, Dict[str, Dict[int, Dict[str, Any]]]],
relationship_mapping: Dict[str, Dict[str, List[dict]]],
) -> tuple[Dict[str, str], Dict[str, int]]:
"""
Create a mapping between original relationship types and their clustered versions.
Args:
full_clustering_results: Clustering results with cluster names
relationship_mapping: Original relationship types organized by source/target
Returns:
Dictionary mapping original relationship type ID to clustered relationship type ID
"""
relationship_type_replacements: Dict[str, str] = {}
reverse_relationship_type_replacements_count: Dict[str, int] = defaultdict(int)
for source_type, target_dict in relationship_mapping.items():
for target_type, rel_types in target_dict.items():
# Get clusters for this source/target pair
clusters = full_clustering_results.get(source_type, {}).get(target_type, {})
for cluster_id, cluster_info in clusters.items():
cluster_name = cluster_info["cluster_name"]
for rel_name in cluster_info["relationships"]:
original_id = f"{source_type}__{rel_name.lower()}__{target_type}"
clustered_id = (
f"{source_type}__{cluster_name.lower()}__{target_type}"
)
relationship_type_replacements[original_id] = clustered_id
reverse_relationship_type_replacements_count[clustered_id] += len(
cluster_info["relationships"]
)
return relationship_type_replacements, reverse_relationship_type_replacements_count
def _create_entity_mapping(
full_entity_clustering_results: Dict[str, Dict[int, Dict[str, Any]]],
entity_mapping: Dict[str, List[dict]],
) -> tuple[Dict[str, str], Dict[str, int]]:
"""
Create a mapping between original entities and their clustered versions.
Args:
full_entity_clustering_results: Clustering results with cluster names
entity_mapping: Original entities organized by entity type
Returns:
Dictionary mapping original entity ID to clustered entity ID
"""
entity_replacements: Dict[str, str] = {}
reverse_entity_replacements_count: Dict[str, int] = defaultdict(int)
for entity_type, clusters in full_entity_clustering_results.items():
for cluster_id, cluster_info in clusters.items():
cluster_name = cluster_info["cluster_name"]
for entity_name in cluster_info["entities"]:
# Skip wildcard entities
if entity_name == "*":
continue
original_id = f"{entity_type}:{entity_name}"
clustered_id = f"{entity_type}:{cluster_name.title()}"
entity_replacements[original_id] = clustered_id
reverse_entity_replacements_count[clustered_id] += len(
cluster_info["entities"]
)
return entity_replacements, reverse_entity_replacements_count
def _create_relationship_mapping(
relationship_type_replacements: Dict[str, str],
reverse_relationship_type_replacements_count: Dict[str, int],
entity_replacements: Dict[str, str],
reverse_entity_replacements_count: Dict[str, int],
relationships: List[
Any
], # This would be List[KGRelationship] but avoiding the import
) -> tuple[Dict[str, str], Dict[str, int]]:
"""
Create a mapping between original relationships and their clustered versions,
taking into account both clustered relationship types and clustered entities.
Args:
relationship_type_replacements: Mapping of original to clustered relationship type IDs
entity_replacements: Mapping of original to clustered entity IDs
relationships: List of relationships from the database
Returns:
Dictionary mapping original relationship ID to clustered relationship ID
"""
relationship_replacements: Dict[str, str] = {}
reverse_relationship_replacements_count: Dict[str, int] = defaultdict(int)
for rel in relationships:
# Skip if source or target is a wildcard
# Get the clustered entities (if they exist)
source_node = entity_replacements.get(rel.source_node, rel.source_node)
target_node = entity_replacements.get(rel.target_node, rel.target_node)
rel.source_document
# Create the relationship type ID
source_type = rel.source_node.split(":")[0]
target_type = rel.target_node.split(":")[0]
rel_type_id = f"{source_type}__{rel.type.lower()}__{target_type}"
# Get the clustered relationship type (if it exists)
clustered_rel_type_id = relationship_type_replacements.get(
rel_type_id, rel_type_id
)
# Extract the relationship name from the clustered type ID
_, rel_name, _ = clustered_rel_type_id.split("__")
# Create the original and clustered relationship IDs
original_id = f"{rel.source_node}__{rel.type.lower()}__{rel.target_node}"
clustered_id = f"{source_node}__{rel_name}__{target_node}"
relationship_replacements[original_id] = clustered_id
reverse_relationship_replacements_count[clustered_id] += rel.occurrences or 1
return relationship_replacements, reverse_relationship_replacements_count
def _match_ungrounded_ge_entities(
    ungrounded_ge_entities: Dict[str, List[str]],
    grounded_ge_entities: Dict[str, List[str]],
    fuzzy_match_threshold: int = 80,
) -> Dict[str, Dict[str, str]]:
    """
    Create a mapping for ungrounded entities by matching them to grounded entities
    or previously processed ungrounded entities. First checks for containment relationships,
    then falls back to fuzzy matching if no containment is found.

    Args:
        ungrounded_ge_entities: Dictionary mapping entity types to lists of ungrounded entity names
        grounded_ge_entities: Dictionary mapping entity types to lists of grounded entity names
        fuzzy_match_threshold: Threshold for fuzzy matching (0-100)

    Returns:
        Dictionary mapping entity types to dictionaries of {original_entity: matched_entity}
    """
    entity_match_mapping: Dict[str, Dict[str, str]] = defaultdict(dict)
    # Ungrounded entities already accepted as new unique entities per type;
    # later entities may match against these as well.
    processed_entities: Dict[str, Set[str]] = defaultdict(set)

    # For each entity type
    for entity_type, ungrounded_entities_list in ungrounded_ge_entities.items():
        grounded_list = grounded_ge_entities.get(entity_type, [])

        # Process each ungrounded entity
        for ungrounded_entity in ungrounded_entities_list:
            # Wildcards are never matched or remapped.
            if ungrounded_entity == "*":
                continue

            best_match = None

            # First check if ungrounded entity is contained in or contains any grounded entities
            for grounded_entity in grounded_list:
                if (
                    ungrounded_entity.lower() in grounded_entity.lower()
                    or grounded_entity.lower() in ungrounded_entity.lower()
                ):
                    best_match = grounded_entity
                    break

            # If no containment match with grounded entities, check previously processed ungrounded entities
            if not best_match:
                for processed_entity in processed_entities[entity_type]:
                    if (
                        ungrounded_entity.lower() in processed_entity.lower()
                        or processed_entity.lower() in ungrounded_entity.lower()
                    ):
                        best_match = processed_entity
                        break

            # If still no match, fall back to fuzzy matching
            if not best_match:
                best_score = 0

                # Try fuzzy matching with grounded entities
                for grounded_entity in grounded_list:
                    score = fuzz.ratio(
                        ungrounded_entity.lower(), grounded_entity.lower()
                    )
                    # Candidate must beat both the threshold and the best so far.
                    if score > fuzzy_match_threshold and score > best_score:
                        best_match = grounded_entity
                        best_score = score

                # Try fuzzy matching with previously processed ungrounded entities
                if not best_match:
                    for processed_entity in processed_entities[entity_type]:
                        score = fuzz.ratio(
                            ungrounded_entity.lower(), processed_entity.lower()
                        )
                        if score > fuzzy_match_threshold and score > best_score:
                            best_match = processed_entity
                            best_score = score

            # Record the mapping
            if best_match:
                entity_match_mapping[entity_type][ungrounded_entity] = best_match
            else:
                # No match found, this becomes a new unique entity
                entity_match_mapping[entity_type][ungrounded_entity] = ungrounded_entity
                # Only unmatched entities become match targets for later ones.
                processed_entities[entity_type].add(ungrounded_entity)

    # Log the results
    logger.info("Entity matching results:")
    for entity_type, mappings in entity_match_mapping.items():
        logger.info(f"\nEntity type: {entity_type}")
        for original, matched in mappings.items():
            if original != matched:
                logger.info(f" Mapped: {original} -> {matched}")
            else:
                logger.info(f" New unique entity: {original}")

    return entity_match_mapping
def _match_determined_ge_entities(
    determined_ge_entity_map: Dict[str, List[str]],
    determined_ge_entities_by_type: Dict[str, List[str]],
    fuzzy_match_threshold: int = 80,
) -> Dict[str, Dict[str, str]]:
    """
    Create a mapping for extracted entity values by matching them against the
    determined (allowed) values configured for each entity type. First checks
    for containment relationships, then falls back to fuzzy matching; values
    that still have no match are mapped to "Other".

    (Previous docstring was copy-pasted from _match_ungrounded_ge_entities and
    documented the wrong parameter names.)

    Args:
        determined_ge_entity_map: Dictionary mapping entity types to their allowed (determined) values
        determined_ge_entities_by_type: Dictionary mapping entity types to the extracted values to match
        fuzzy_match_threshold: Threshold for fuzzy matching (0-100)

    Returns:
        Dictionary mapping entity types to dictionaries of {original_entity: matched_entity}
    """
    determined_entity_match_mapping: Dict[str, Dict[str, str]] = defaultdict(dict)

    # For each entity type
    for entity_type, determined_entities_list in determined_ge_entity_map.items():
        ungrounded_list = determined_ge_entities_by_type.get(entity_type, [])

        # Process each ungrounded entity
        for ungrounded_entity in ungrounded_list:
            # Wildcards are never matched or remapped.
            if ungrounded_entity == "*":
                continue

            best_match = None

            # First check if ungrounded entity is contained in or contains any grounded entities
            for grounded_entity in determined_entities_list:
                if (
                    ungrounded_entity.lower() in grounded_entity.lower()
                    or grounded_entity.lower() in ungrounded_entity.lower()
                ):
                    best_match = grounded_entity
                    break

            # If still no match, fall back to fuzzy matching
            if not best_match:
                best_score = 0

                # Try fuzzy matching with grounded entities
                for grounded_entity in determined_entities_list:
                    score = fuzz.ratio(
                        ungrounded_entity.lower(), grounded_entity.lower()
                    )
                    # Candidate must beat both the threshold and the best so far.
                    if score > fuzzy_match_threshold and score > best_score:
                        best_match = grounded_entity
                        best_score = score

            # Record the mapping
            if best_match:
                determined_entity_match_mapping[entity_type][
                    f"{ungrounded_entity}"
                ] = f"{best_match}"
            else:
                # No match found, this becomes a new unique entity
                determined_entity_match_mapping[entity_type][
                    f"{ungrounded_entity}"
                ] = "Other"

    # Log the results
    logger.info("Entity matching results:")
    for entity_type, mappings in determined_entity_match_mapping.items():
        logger.info(f"\nEntity type: {entity_type}")
        for original, matched in mappings.items():
            if original != matched:
                logger.info(f" Mapped: {original} -> {matched}")
            else:
                logger.info(f" New unique entity: {original}")

    return determined_entity_match_mapping
def kg_clustering(
    tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
) -> None:
    """
    Here we will cluster the extractions based on their cluster frameworks.
    Initially, this will only focus on grounded entities with pre-determined
    relationships, so 'clustering' is actually not yet required.
    However, we may need to reconcile entities coming from different sources.

    The primary purpose of this function is to populate the actual KG tables
    from the temp_extraction tables.

    This will change with deep extraction, where grounded-sourceless entities
    can be extracted and then need to be clustered.
    """
    # NOTE(review): index_name and processing_chunk_batch_size are currently
    # unused — presumably reserved for the future clustering logic; confirm.
    logger.info(f"Starting kg clustering for tenant {tenant_id}")

    ## Retrieval

    # Documents whose relationships did / did not transfer cleanly.
    source_documents_w_successful_transfers: Set[str] = set()
    source_documents_w_failed_transfers: Set[str] = set()

    # NOTE(review): the LLMs are fetched but never used below — presumably for
    # the upcoming clustering implementation; confirm before removing.
    primary_llm, fast_llm = get_default_llms()

    with get_session_with_current_tenant() as db_session:
        relationship_types = get_all_relationship_types(
            db_session, kg_stage=KGStage.EXTRACTED
        )
        relationships = get_all_relationships(db_session, kg_stage=KGStage.EXTRACTED)
        grounded_entities = get_entities_by_grounding(
            db_session, KGStage.EXTRACTED, KGGroundingType.GROUNDED
        )

    ## Clustering

    # TODO: re-implement clustering of ungrounded entities as well as
    # grounded entities that do not have a source document with deep extraction
    # enabled!
    # For now we would just create a trivial entity mapping from the
    # 'unclustered' entities to the 'clustered' entities. So we can simply
    # transfer the entity information from the Staging to the Normalized
    # tables.
    # This will be reimplemented when deep extraction is enabled.

    ## Database operations

    # create the clustered objects - entities
    # (one short-lived session per row so a single failure cannot poison the rest)
    transferred_entities: list[str] = []
    for grounded_entity in grounded_entities:
        with get_session_with_current_tenant() as db_session:
            added_entity = add_entity(
                db_session,
                KGStage.NORMALIZED,
                entity_type=grounded_entity.entity_type_id_name,
                name=grounded_entity.name,
                occurrences=grounded_entity.occurrences or 1,
                document_id=grounded_entity.document_id or None,
                attributes=grounded_entity.attributes or None,
            )
            db_session.commit()

            if added_entity:
                transferred_entities.append(added_entity.id_name)

    # create the clustered objects - relationship types
    transferred_relationship_types: list[str] = []
    for relationship_type in relationship_types:
        with get_session_with_current_tenant() as db_session:
            added_relationship_type_id_name = add_relationship_type(
                db_session,
                KGStage.NORMALIZED,
                source_entity_type=relationship_type.source_entity_type_id_name,
                relationship_type=relationship_type.type,
                target_entity_type=relationship_type.target_entity_type_id_name,
                extraction_count=relationship_type.occurrences or 1,
            )
            db_session.commit()

            transferred_relationship_types.append(added_relationship_type_id_name)

    # create the clustered objects - relationships, tracking which source
    # documents transferred fully vs. had at least one failure
    transferred_relationships: list[str] = []
    for relationship in relationships:
        with get_session_with_current_tenant() as db_session:
            try:
                added_relationship = add_relationship(
                    db_session,
                    KGStage.NORMALIZED,
                    relationship_id_name=relationship.id_name,
                    source_document_id=relationship.source_document or "",
                    occurrences=relationship.occurrences or 1,
                )

                if relationship.source_document:
                    source_documents_w_successful_transfers.add(
                        relationship.source_document
                    )

                db_session.commit()

                transferred_relationships.append(added_relationship.id_name)
            except Exception as e:
                if relationship.source_document:
                    source_documents_w_failed_transfers.add(
                        relationship.source_document
                    )
                logger.error(
                    f"Error transferring relationship {relationship.id_name}: {e}"
                )

    # TODO: remove the /relationship types & entities that correspond to relationships
    # source documents that failed to transfer. I.e, do a proper rollback

    # TODO: update Vespa info when clustering/changes are performed

    # delete the added objects from the staging tables
    # (best-effort: deletion failures are logged but do not abort the run)
    try:
        with get_session_with_current_tenant() as db_session:
            delete_relationships_by_id_names(
                db_session, transferred_relationships, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationships: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            delete_relationship_types_by_id_names(
                db_session, transferred_relationship_types, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationship types: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            delete_entities_by_id_names(
                db_session, transferred_entities, kg_stage=KGStage.EXTRACTED
            )
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting entities: {e}")

    # Update document kg info
    # with get_session_with_current_tenant() as db_session:
    #     all_kg_extracted_documents_info = get_all_kg_extracted_documents_info(
    #         db_session
    #     )

    for document_id in source_documents_w_successful_transfers:
        # Update the document kg info
        with get_session_with_current_tenant() as db_session:
            update_document_kg_info(
                db_session,
                document_id=document_id,
                kg_stage=KGStage.NORMALIZED,
            )
            db_session.commit()

View File

@@ -0,0 +1,311 @@
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
from thefuzz import fuzz # type: ignore
from thefuzz import process # type: ignore
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_entity_names_for_types
from onyx.db.relationships import get_relationships_for_entity_type_pairs
from onyx.kg.models import NormalizedEntities
from onyx.kg.models import NormalizedRelationships
from onyx.kg.models import NormalizedTerms
from onyx.kg.utils.embeddings import encode_string_batch
def _split_entity_type_v_name(entity: str) -> tuple[str, str]:
"""
Split an entity string into type and name.
"""
entity_split = entity.split(":")
if len(entity_split) < 2:
raise ValueError(f"Invalid entity: {entity}")
entity_type = entity_split[0]
entity_name = ":".join(entity_split[1:])
return entity_type, entity_name
def _get_existing_normalized_entities(
    raw_entities: List[str],
) -> List[tuple[str, str | None]]:
    """Fetch the already-normalized entities for the entity types seen in the input.

    Only the type prefix (text before the first colon) of each raw entity is
    used; the database is queried once for all distinct types.
    """
    distinct_types = {raw_entity.split(":")[0] for raw_entity in raw_entities}
    with get_session_with_current_tenant() as db_session:
        return get_entity_names_for_types(db_session, list(distinct_types))
def _get_existing_normalized_relationships(
    raw_relationships: List[str],
) -> Dict[str, Dict[str, List[str]]]:
    """Fetch known relationships grouped by source and target entity type.

    Each raw relationship is formatted ``SRC_TYPE:name__rel__TGT_TYPE:name``;
    only the distinct (source type, target type) pairs are used for the lookup.
    """
    type_pairs = set()
    for raw_relationship in raw_relationships:
        parts = raw_relationship.split("__")
        type_pairs.add((parts[0].split(":")[0], parts[2].split(":")[0]))

    with get_session_with_current_tenant() as db_session:
        found = get_relationships_for_entity_type_pairs(db_session, list(type_pairs))

    grouped: Dict[str, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
    for rel in found:
        grouped[rel.source_entity_type_id_name][rel.target_entity_type_id_name].append(
            rel.id_name
        )
    return grouped
def normalize_entities(
    raw_entities_no_attributes: List[str],
) -> NormalizedEntities:
    """
    Match each entity against the known normalized entities using fuzzy matching.

    Args:
        raw_entities_no_attributes: entity strings of the form ``TYPE:name``
            (without attributes) to normalize.

    Returns:
        NormalizedEntities with the matched entity id-names and a map from each
        raw entity to its normalized form (or to itself when no sufficiently
        good match was found).
    """
    # TODO: this probably should move to a new Vespa schema for entity normalization
    # TODO: as is, this should be converted to a generator going through the entities
    # in large batches, to avoid memory issues

    norm_entities = _get_existing_normalized_entities(raw_entities_no_attributes)

    # entity type -> semantic name -> actual id (used to swap the semantic
    # id for the real id after fuzzy matching)
    norm_entity_semantic_to_id_map: dict[str, dict[str, str]] = defaultdict(dict)
    for norm_entity_tuple in norm_entities:
        if norm_entity_tuple[1] is None:
            continue
        entity_type, norm_entity_semantic_name = _split_entity_type_v_name(
            norm_entity_tuple[1]
        )
        norm_entity_semantic_to_id_map[entity_type][norm_entity_semantic_name] = (
            _split_entity_type_v_name(norm_entity_tuple[0])[1]
        )

    normalized_results: List[str] = []
    normalized_map: Dict[str, str | None] = {}
    threshold = 80  # Adjust threshold as needed

    base_norm_entities = [norm_entity[0] for norm_entity in norm_entities]

    for entity in raw_entities_no_attributes:
        # Use the shared splitter so entity names that themselves contain
        # ":" do not crash a plain two-value unpack.
        entity_type, entity_name = _split_entity_type_v_name(entity)

        # Wildcard entities are passed through unchanged.
        if entity_name == "*":
            normalized_results.append(entity)
            normalized_map[entity] = entity
            continue

        # Find the best match and its score from the known entities
        all_entity_match_possibilities = norm_entity_semantic_to_id_map[
            entity_type
        ].keys()
        if not all_entity_match_possibilities:
            # No known entities of this type; extractOne cannot be called on
            # an empty candidate set, so keep the original.
            normalized_map[entity] = entity
            continue

        best_match, score = process.extractOne(
            entity_name, all_entity_match_possibilities, scorer=fuzz.partial_ratio
        )
        if best_match not in base_norm_entities:
            # replace semantic_id with the actual id
            best_match_id = norm_entity_semantic_to_id_map[entity_type][best_match]
        else:
            best_match_id = best_match

        if score >= threshold:
            normalized_results.append(f"{entity_type}:{best_match_id}")
            normalized_map[entity] = f"{entity_type}:{best_match_id}"
        else:
            # If no good match found, keep original
            normalized_map[entity] = entity

    return NormalizedEntities(
        entities=normalized_results, entity_normalization_map=normalized_map
    )
def normalize_entities_w_attributes_from_map(
    raw_entities_w_attributes: List[str],
    entity_normalization_map: Dict[str, Optional[str]],
) -> List[str]:
    """
    Normalize entities with attributes using the entity normalization map.

    Each input string must look like ``<entity>--<attributes>``. Entities whose
    normalized form is unknown (mapped to None or absent) are dropped from the
    result.

    Raises:
        ValueError: if an input string does not contain exactly one ``--``.
    """
    normalized_entities_w_attributes: List[str] = []
    for raw_entity_w_attribute in raw_entities_w_attributes:
        parts = raw_entity_w_attribute.split("--")
        # Raise instead of assert so the validation survives `python -O`.
        if len(parts) != 2:
            raise ValueError(
                f"Invalid entity with attributes: {raw_entity_w_attribute}"
            )
        raw_entity, attributes = parts
        normalized_entity = entity_normalization_map.get(raw_entity.strip())
        if normalized_entity is None:
            continue
        # Reuse the already-split attribute part instead of re-splitting.
        normalized_entities_w_attributes.append(
            f"{normalized_entity}--{attributes.strip()}"
        )
    return normalized_entities_w_attributes
def normalize_relationships(
    raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
) -> NormalizedRelationships:
    """
    Normalize relationships using entity mappings and relationship string matching.

    Args:
        raw_relationships: relationships in format "source__relation__target"
        entity_normalization_map: mapping of raw entities to normalized ones
            (or None when the entity could not be normalized)

    Returns:
        NormalizedRelationships containing normalized relationships and a map
        from each raw relationship to its normalized form (or None).
    """
    known_relationships = _get_existing_normalized_relationships(raw_relationships)

    normalized_rels: List[str] = []
    normalization_map: Dict[str, str | None] = {}

    for raw_rel in raw_relationships:
        # 1. Split into source / relation / target
        rel_parts = raw_rel.split("__")
        if len(rel_parts) != 3:
            raise ValueError(f"Invalid relationship format: {raw_rel}")
        source, rel_string, target = rel_parts

        # Both endpoints must have been normalized successfully.
        norm_source = entity_normalization_map.get(source)
        norm_target = entity_normalization_map.get(target)
        if norm_source is None or norm_target is None:
            normalization_map[raw_rel] = None
            continue

        # 2. Collect candidate relation strings for this (source, target) type pair
        source_type = norm_source.split(":")[0]
        target_type = norm_target.split(":")[0]
        candidate_rels = [
            known_rel.split("__")[1]
            for known_rel in known_relationships.get(source_type, {}).get(
                target_type, []
            )
        ]
        if not candidate_rels:
            normalization_map[raw_rel] = None
            continue

        # 3. Embed the raw relation together with all candidates and pick the
        # candidate with the highest dot-product similarity.
        vectors = encode_string_batch([rel_string] + candidate_rels)
        similarities = np.dot(vectors[1:], vectors[0])
        best_candidate = candidate_rels[int(np.argmax(similarities))]

        norm_rel = f"{norm_source}__{best_candidate}__{norm_target}"
        normalized_rels.append(norm_rel)
        normalization_map[raw_rel] = norm_rel

    return NormalizedRelationships(
        relationships=normalized_rels, relationship_normalization_map=normalization_map
    )
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
    """
    Normalize terms.

    Semantic-similarity based normalization is not implemented yet, so every
    term currently normalizes to itself.

    Args:
        raw_terms: list of terms to normalize

    Returns:
        NormalizedTerms containing the (unchanged) terms and an identity
        normalization map.
    """
    # TODO: implement embedding-based matching against a curated list of
    # normalized terms (encode raw terms and candidates with
    # encode_string_batch, pick the best dot-product match above a threshold,
    # and map unmatched terms to None).
    return NormalizedTerms(
        terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
    )

View File

@@ -0,0 +1,53 @@
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entity_type import populate_default_employee_account_information
from onyx.db.entity_type import (
populate_default_primary_grounded_entity_type_information,
)
from onyx.db.kg_config import get_kg_enablement
from onyx.db.kg_config import KGConfigSettings
from onyx.utils.logger import setup_logger
logger = setup_logger()
def populate_default_grounded_entity_types() -> None:
    """Populate the default primary grounded entity types for this tenant.

    Raises:
        ValueError: if the KG feature is not enabled.
    """
    with get_session_with_current_tenant() as db_session:
        if not get_kg_enablement(db_session):
            message = (
                "KG approach is not enabled, the entity types cannot be populated."
            )
            logger.error(message)
            raise ValueError(message)
        populate_default_primary_grounded_entity_type_information(db_session)
        db_session.commit()
    return None
def populate_default_account_employee_definitions() -> None:
    """Populate the default ACCOUNT/EMPLOYEE entity definitions for this tenant.

    Raises:
        ValueError: if the KG feature is not enabled.
    """
    with get_session_with_current_tenant() as db_session:
        if not get_kg_enablement(db_session):
            message = (
                "KG approach is not enabled, the entity types cannot be populated."
            )
            logger.error(message)
            raise ValueError(message)
        populate_default_employee_account_information(db_session)
        db_session.commit()
    return None
def execute_kg_setting_tests(kg_config_settings: KGConfigSettings) -> None:
    """Validate that the minimal KG configuration is present.

    Checks the settings in order and raises on the first missing one.

    Raises:
        ValueError: naming the first disabled/missing setting.
    """
    required_settings = (
        (kg_config_settings.KG_ENABLED, "KG is not enabled"),
        (kg_config_settings.KG_VENDOR, "KG_VENDOR is not set"),
        (kg_config_settings.KG_VENDOR_DOMAINS, "KG_VENDOR_DOMAINS is not set"),
    )
    for setting_value, error_message in required_settings:
        if not setting_value:
            raise ValueError(error_message)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,84 @@
from pydantic import BaseModel
from onyx.kg.models import KGDefaultEntityDefinition
from onyx.kg.models import KGGroundingType
class KGDefaultPrimaryGroundedEntityDefinitions(BaseModel):
    """Default grounded entity types, one per supported document source.

    Each field maps a source connector to its default KG entity definition.
    NOTE(review): the ``---vendor_name---`` placeholder in descriptions is
    presumably substituted with the configured vendor name downstream — confirm.
    """

    LINEAR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A formal ticket about a product issue or improvement request.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="linear",
    )
    FIREFLIES: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A phone call transcript between us (---vendor_name---) \
and another account or individuals, or an internal meeting.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="fireflies",
    )
    GONG: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A phone call transcript between us (---vendor_name---) \
and another account or individuals, or an internal meeting.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="gong",
    )
    SLACK: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A Slack conversation.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="slack",
    )
    WEB: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A web page.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="web",
    )
    GOOGLE_DRIVE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A Google Drive document.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="google_drive",
    )
    GMAIL: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="An email.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="gmail",
    )
    JIRA: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A formal JIRA ticket about a product issue or improvement request.",
        grounding=KGGroundingType.GROUNDED,
        grounded_source_name="jira",
    )
class KGDefaultAccountEmployeeDefinitions(BaseModel):
    """Default VENDOR / ACCOUNT / EMPLOYEE entity-type definitions.

    These start inactive (active=False) and have no grounded source connector.
    """

    VENDOR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="The Vendor ---vendor_name---, 'us'",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )
    ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A company that was, is, or potentially could be a customer of the vendor \
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )
    EMPLOYEE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
        description="A person who speaks on \
behalf of 'our' company (the VENDOR ---vendor_name---), NOT of another account. Therefore, employees of other companies \
are NOT included here. If in doubt, do NOT extract.",
        grounding=KGGroundingType.GROUNDED,
        active=False,
        grounded_source_name=None,
    )

219
backend/onyx/kg/models.py Normal file
View File

@@ -0,0 +1,219 @@
from enum import Enum
from typing import Any
from typing import Dict
from pydantic import BaseModel
class KGConfigSettings(BaseModel):
    """Knowledge-graph configuration values for a tenant."""

    KG_ENABLED: bool = False
    # The vendor ("us") company name; None until configured.
    KG_VENDOR: str | None = None
    # Email domains belonging to the vendor.
    KG_VENDOR_DOMAINS: list[str] | None = None
    # People whose email domain matches one of these are skipped during
    # KG person processing.
    KG_IGNORE_EMAIL_DOMAINS: list[str] | None = None


class KGConfigVars(str, Enum):
    """String names of the KG config variables (one per KGConfigSettings field)."""

    KG_ENABLED = "KG_ENABLED"
    KG_VENDOR = "KG_VENDOR"
    KG_VENDOR_DOMAINS = "KG_VENDOR_DOMAINS"
    KG_IGNORE_EMAIL_DOMAINS = "KG_IGNORE_EMAIL_DOMAINS"
class KGChunkFormat(BaseModel):
    """A document chunk in the shape used for KG extraction."""

    connector_id: int | None = None
    document_id: str
    chunk_id: int
    title: str
    content: str
    primary_owners: list[str]
    secondary_owners: list[str]
    source_type: str
    metadata: dict[str, str | list[str]] | None = None
    # Extraction results; presumably name -> occurrence count — confirm.
    entities: Dict[str, int] = {}
    relationships: Dict[str, int] = {}
    terms: Dict[str, int] = {}
    deep_extraction: bool = False


class KGChunkExtraction(BaseModel):
    """Raw extraction output for a single chunk."""

    connector_id: int
    document_id: str
    chunk_id: int
    core_entity: str
    entities: list[str]
    relationships: list[str]
    terms: list[str]
    attributes: dict[str, str | list[str]]


class KGChunkId(BaseModel):
    """Minimal identifier for a chunk (connector, document, chunk index)."""

    connector_id: int | None = None
    document_id: str
    chunk_id: int
class KGRelationshipExtraction(BaseModel):
    """A single extracted relationship plus the document it came from."""

    relationship_str: str
    source_document_id: str


class KGAggregatedExtractions(BaseModel):
    """Extraction results aggregated across chunks/documents."""

    grounded_entities_document_ids: dict[str, str]
    entities: dict[str, int]
    relationships: dict[str, dict[str, int]]
    terms: dict[str, int]
    attributes: dict[str, dict[str, str | list[str]]]


class KGBatchExtractionStats(BaseModel):
    """Outcome of one extraction batch: per-chunk successes/failures plus results."""

    connector_id: int | None = None
    succeeded: list[KGChunkId]
    failed: list[KGChunkId]
    aggregated_kg_extractions: KGAggregatedExtractions


class ConnectorExtractionStats(BaseModel):
    """Per-connector extraction counters."""

    connector_id: int
    num_succeeded: int
    num_failed: int
    num_processed: int
class KGPerson(BaseModel):
    """A person parsed from an email address."""

    name: str
    company: str
    # True when the person belongs to the vendor's own organization.
    employee: bool


class NormalizedEntities(BaseModel):
    """Result of entity normalization: matched entities and raw->normalized map."""

    entities: list[str]
    entity_normalization_map: dict[str, str | None]


class NormalizedRelationships(BaseModel):
    """Result of relationship normalization (None = could not be normalized)."""

    relationships: list[str]
    relationship_normalization_map: dict[str, str | None]


class NormalizedTerms(BaseModel):
    """Result of term normalization."""

    terms: list[str]
    term_normalization_map: dict[str, str | None]
class KGClassificationContent(BaseModel):
    """Content of a document to be classified for KG processing."""

    document_id: str
    classification_content: str
    source_type: str
    source_metadata: dict[str, Any] | None = None
    entity_type: str | None = None
    metadata: dict[str, Any] | None = None


class KGEnrichedClassificationContent(KGClassificationContent):
    """Classification content enriched with classification/extraction directives."""

    classification_enabled: bool
    classification_instructions: dict[str, Any]
    deep_extraction: bool


class KGClassificationDecisions(BaseModel):
    """Classification outcome for one document."""

    document_id: str
    classification_decision: bool
    classification_class: str | None
    source_metadata: dict[str, Any] | None = None


class KGClassificationRule(BaseModel):
    """A single classification rule."""

    description: str
    # NOTE(review): field name looks like a typo of "extraction"; renaming
    # would break existing callers/serialized data, so it is left as-is.
    extration: bool


class KGClassificationInstructions(BaseModel):
    """Per-entity-type classification configuration."""

    classification_enabled: bool
    classification_options: str
    classification_class_definitions: dict[str, Dict[str, str | bool]]
class KGExtractionInstructions(BaseModel):
    """Flags controlling extraction for an entity type."""

    deep_extraction: bool
    active: bool


class KGEntityTypeInstructions(BaseModel):
    """Bundle of classification, extraction, and filter instructions."""

    classification_instructions: KGClassificationInstructions
    extraction_instructions: KGExtractionInstructions
    filter_instructions: dict[str, Any] | None = None


class KGEnhancedDocumentMetadata(BaseModel):
    """Document metadata enriched with KG processing directives."""

    entity_type: str | None
    document_attributes: dict[str, Any] | None
    deep_extraction: bool
    classification_enabled: bool
    classification_instructions: KGClassificationInstructions | None
    # Presumably True when the document should be excluded from KG
    # processing — confirm against callers.
    skip: bool
class ContextPreparation(BaseModel):
    """
    Context preparation format for the LLM KG extraction.
    """

    llm_context: str
    core_entity: str
    implied_entities: list[str]
    implied_relationships: list[str]
    implied_terms: list[str]


class KGDocumentClassificationPrompt(BaseModel):
    """
    Document classification prompt format for the LLM KG extraction.
    """

    # None when the source type has no classification prompt.
    llm_prompt: str | None


class KGConnectorData(BaseModel):
    """Minimal connector identification (id and source name)."""

    id: int
    source: str


class KGStage(str, Enum):
    """Lifecycle stage of a document/object in the KG pipeline."""

    EXTRACTED = "extracted"
    NORMALIZED = "normalized"
    FAILED = "failed"
    SKIPPED = "skipped"
    NOT_STARTED = "not_started"
    EXTRACTING = "extracting"
    DO_NOT_EXTRACT = "do_not_extract"
class KGDocumentEntitiesRelationshipsAttributes(BaseModel):
    """Entities, relationships, and attributes derived for one document."""

    kg_core_document_id_name: str
    implied_entities: set[str]
    implied_relationships: set[str]
    # 'Relationships' that were recorded as entity attributes instead.
    converted_relationships_to_attributes: dict[str, list[str]]
    company_participant_emails: set[str]
    account_participant_emails: set[str]
    # Attributes that were recorded as entities/relationships instead.
    converted_attributes_to_relationships: set[str]
    document_attributes: dict[str, Any] | None


class KGGroundingType(str, Enum):
    """Whether an entity type is grounded in concrete source documents."""

    UNGROUNDED = "ungrounded"
    GROUNDED = "grounded"


class KGDefaultEntityDefinition(BaseModel):
    """Default definition of an entity type (description, grounding, source)."""

    description: str
    grounding: KGGroundingType
    active: bool = False
    grounded_source_name: str | None
    attributes: dict = {}
    entity_values: dict = {}


class KGEntityInformation(BaseModel):
    """Summary information about an entity."""

    entity_type: str
    entity_name: str
    # NOTE(review): "occurences" is a misspelling of "occurrences"; renaming
    # would break existing callers, so the field name is kept.
    occurences: int

View File

@@ -0,0 +1 @@
# TODO: Implement this

View File

@@ -0,0 +1,21 @@
from onyx.db.document import update_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipTypeExtractionStaging
from onyx.kg.models import KGStage
def reset_extraction_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Clears every extraction staging table, then moves documents that were in
    the EXTRACTED stage back to NOT_STARTED.
    """
    staging_models = (
        KGRelationshipExtractionStaging,
        KGEntityExtractionStaging,
        KGRelationshipTypeExtractionStaging,
    )
    with get_session_with_current_tenant() as db_session:
        for staging_model in staging_models:
            db_session.query(staging_model).delete()
        db_session.commit()

    with get_session_with_current_tenant() as db_session:
        update_document_kg_stages(db_session, KGStage.EXTRACTED, KGStage.NOT_STARTED)
        db_session.commit()

View File

@@ -0,0 +1,26 @@
from onyx.db.document import reset_all_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.db.models import KGRelationshipType
from onyx.db.models import KGRelationshipTypeExtractionStaging
def reset_full_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Clears both the normalized KG tables and the extraction staging tables,
    then resets the KG stage on all documents.
    """
    kg_models = (
        KGRelationship,
        KGRelationshipType,
        KGEntity,
        KGRelationshipExtractionStaging,
        KGEntityExtractionStaging,
        KGRelationshipTypeExtractionStaging,
    )
    with get_session_with_current_tenant() as db_session:
        for kg_model in kg_models:
            db_session.query(kg_model).delete()
        db_session.commit()

    with get_session_with_current_tenant() as db_session:
        reset_all_document_kg_stages(db_session)
        db_session.commit()

View File

@@ -0,0 +1,22 @@
from onyx.db.document import update_document_kg_stages
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.models import KGRelationship
from onyx.db.models import KGRelationshipType
from onyx.kg.models import KGStage
def reset_normalization_kg_index() -> None:
    """
    Resets the knowledge graph index.

    Clears the normalized KG tables, then moves documents in the NORMALIZED
    stage back to EXTRACTED.
    """
    normalized_models = (KGRelationship, KGEntity, KGRelationshipType)
    with get_session_with_current_tenant() as db_session:
        for normalized_model in normalized_models:
            db_session.query(normalized_model).delete()
        db_session.commit()

    with get_session_with_current_tenant() as db_session:
        update_document_kg_stages(db_session, KGStage.NORMALIZED, KGStage.EXTRACTED)
        db_session.commit()

View File

@@ -0,0 +1,23 @@
from typing import List
import numpy as np
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
def encode_string_batch(strings: List[str]) -> np.ndarray:
    """Embed a batch of strings using the current search settings' model."""
    with get_session_with_current_tenant() as db_session:
        search_settings = get_current_search_settings(db_session)
        embedding_model = EmbeddingModel.from_db_model(
            search_settings=search_settings,
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )
        # Encode while the session is still open.
        embeddings = embedding_model.encode(strings, text_type=EmbedTextType.QUERY)
    return np.array(embeddings)

View File

@@ -0,0 +1,410 @@
import re
from collections import defaultdict
from typing import Dict
from onyx.configs.constants import OnyxCallTypes
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.entities import get_kg_entity_by_document
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import KGConfigSettings
from onyx.db.models import Document
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.models import (
KGDocumentClassificationPrompt,
)
from onyx.kg.models import KGDocumentEntitiesRelationshipsAttributes
from onyx.kg.models import KGEnhancedDocumentMetadata
from onyx.kg.utils.formatting_utils import generalize_entities
from onyx.kg.utils.formatting_utils import kg_email_processing
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT
def _update_implied_entities_relationships(
    kg_core_document_id_name: str,
    owner_list: list[str],
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
    converted_relationships_to_attributes: dict[str, list[str]],
) -> tuple[set[str], set[str], set[str], set[str], dict[str, list[str]]]:
    """Fold each owner into the implied entities/relationships accumulators.

    Owners that look like email addresses are processed as people via
    kg_process_person; any other owner string is recorded as an attribute
    under ``relationship_type`` instead.
    """
    for owner in owner_list or []:
        if not is_email(owner):
            converted_relationships_to_attributes[relationship_type].append(owner)
            continue
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
        ) = kg_process_person(
            owner,
            kg_core_document_id_name,
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            relationship_type,
            kg_config_settings,
        )
    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
        converted_relationships_to_attributes,
    )
def kg_document_entities_relationships_attribute_generation(
    document: Document,
    doc_metadata: KGEnhancedDocumentMetadata,
    active_entities: list[str],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentEntitiesRelationshipsAttributes:
    """
    Generate entities, relationships, and attributes for a document.

    Owners are folded in as people (calls treat all owners as participants;
    other documents treat primary owners as leads), and document attributes
    matching an active entity type are converted into entities/relationships.
    """
    # Get document entity type from the KGEnhancedDocumentMetadata
    document_entity_type = doc_metadata.entity_type
    assert document_entity_type is not None

    # Get additional document attributes from the KGEnhancedDocumentMetadata
    document_attributes = doc_metadata.document_attributes

    implied_entities: set[str] = set()
    implied_relationships: set[str] = (
        set()
    )  # 'Relationships' that will be captured as KG relationships
    converted_relationships_to_attributes: dict[str, list[str]] = defaultdict(
        list
    )  # 'Relationships' that will be captured as KG entity attributes
    converted_attributes_to_relationships: set[str] = (
        set()
    )  # Attributes that should be captures as entities and then relationships (Account = ...)
    company_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - participants from vendor
    account_participant_emails: set[str] = (
        set()
    )  # Quantity needed for call processing - external participants

    # Chunk treatment variables
    document_is_from_call = document_entity_type.lower() in [
        call_type.value.lower() for call_type in OnyxCallTypes
    ]

    # Get core entity
    document_id = document.id
    primary_owners = document.primary_owners
    secondary_owners = document.secondary_owners

    with get_session_with_current_tenant() as db_session:
        kg_core_document = get_kg_entity_by_document(db_session, document_id)

        if kg_core_document:
            kg_core_document_id_name = kg_core_document.id_name
        else:
            # No KG entity yet: derive the id-name from type and document id.
            kg_core_document_id_name = f"{document_entity_type.upper()}:{document_id}"

    # Get implied entities and relationships from primary/secondary owners
    if document_is_from_call:
        # Calls: every owner is a participant.
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=(primary_owners or []) + (secondary_owners or []),
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
    else:
        # Non-calls: primary owners "lead", secondary owners "participate".
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=primary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="leads",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )
        (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
            converted_relationships_to_attributes,
        ) = _update_implied_entities_relationships(
            kg_core_document_id_name,
            owner_list=secondary_owners or [],
            implied_entities=implied_entities,
            implied_relationships=implied_relationships,
            company_participant_emails=company_participant_emails,
            account_participant_emails=account_participant_emails,
            relationship_type="participates_in",
            kg_config_settings=kg_config_settings,
            converted_relationships_to_attributes=converted_relationships_to_attributes,
        )

    if document_attributes is not None:
        cleaned_document_attributes = document_attributes.copy()
        for attribute, value in document_attributes.items():
            # Attributes whose name matches an active entity type become
            # entities plus "is_<attr>_of" relationships (both to the concrete
            # document and to the generalized "*" forms).
            if attribute.lower() in [x.lower() for x in active_entities]:
                converted_attributes_to_relationships.add(attribute)
                if isinstance(value, str):
                    implied_entity = f"{attribute.upper()}:{value.capitalize()}"
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
                    )
                    implied_entity = f"{attribute.upper()}:*"
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
                    )
                    implied_entity = f"{attribute.upper()}:*"
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        f"{implied_entity}__is_{attribute.lower()}_of__{document_entity_type.upper()}:*"
                    )
                    implied_entity = f"{attribute.upper()}:{value.capitalize()}"
                    implied_entities.add(implied_entity)
                    implied_relationships.add(
                        f"{implied_entity}__is_{attribute.lower()}_of__{document_entity_type.upper()}:*"
                    )
                    # Converted attributes are removed from the cleaned copy.
                    cleaned_document_attributes.pop(attribute)
                elif isinstance(value, list):
                    for item in value:
                        implied_entity = f"{attribute.upper()}:{item.capitalize()}"
                        implied_entities.add(implied_entity)
                        implied_relationships.add(
                            f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
                        )
                    cleaned_document_attributes.pop(attribute)
            # Id-style attributes are dropped from the cleaned copy.
            # NOTE(review): if an active-entity attribute also ends in
            # "_id"/"Id" it was already popped above, and this pop (no
            # default) would raise KeyError — confirm intended.
            if attribute.lower().endswith("_id") or attribute.endswith("Id"):
                cleaned_document_attributes.pop(attribute)
    else:
        cleaned_document_attributes = None

    return KGDocumentEntitiesRelationshipsAttributes(
        kg_core_document_id_name=kg_core_document_id_name,
        implied_entities=implied_entities,
        implied_relationships=implied_relationships,
        company_participant_emails=company_participant_emails,
        account_participant_emails=account_participant_emails,
        converted_relationships_to_attributes=converted_relationships_to_attributes,
        converted_attributes_to_relationships=converted_attributes_to_relationships,
        document_attributes=cleaned_document_attributes,
    )
def _prepare_llm_document_content_call(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definition_string: str,
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Calls - prepare prompt for the LLM classification.
    """
    return KGDocumentClassificationPrompt(
        llm_prompt=CALL_DOCUMENT_CLASSIFICATION_PROMPT.format(
            beginning_of_call_content=document_classification_content.classification_content,
            category_list=category_list,
            category_options=category_definition_string,
            vendor=kg_config_settings.KG_VENDOR,
        ),
    )
def kg_process_person(
    person: str,
    core_document_id_name: str,
    implied_entities: set[str],
    implied_relationships: set[str],
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    relationship_type: str,
    kg_config_settings: KGConfigSettings,
) -> tuple[set[str], set[str], set[str], set[str]]:
    """
    Process a single owner and return updated sets with entities and relationships.

    Vendor employees become EMPLOYEE/VENDOR entities; everyone else becomes
    an ACCOUNT entity. People whose company matches an ignored email domain
    are skipped entirely.

    Returns:
        tuple containing (implied_entities, implied_relationships, company_participant_emails, account_participant_emails)
    """
    # NOTE(review): the kg_config_settings parameter is immediately
    # overwritten by a fresh DB read — confirm whether the parameter is
    # still needed.
    with get_session_with_current_tenant() as db_session:
        kg_config_settings = get_kg_config_settings(db_session)

    if not kg_config_settings.KG_ENABLED:
        raise ValueError("KG is not enabled")
    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)

    kg_person = kg_email_processing(person, kg_config_settings)

    # Skip people from explicitly ignored email domains.
    if any(
        domain.lower() in kg_person.company.lower()
        for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
    ):
        return (
            implied_entities,
            implied_relationships,
            company_participant_emails,
            account_participant_emails,
        )

    if kg_person.employee:
        company_participant_emails = company_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        if kg_person.name not in implied_entities:
            generalized_target_entity = list(
                generalize_entities([core_document_id_name])
            )[0]
            implied_entities = implied_entities | {f"EMPLOYEE:{kg_person.name}"}
            implied_relationships = implied_relationships | {
                f"EMPLOYEE:{kg_person.name}__{relationship_type}__{core_document_id_name}",
                f"EMPLOYEE:{kg_person.name}__{relationship_type}__{generalized_target_entity}",
                f"EMPLOYEE:*__{relationship_type}__{core_document_id_name}",
                f"EMPLOYEE:*__{relationship_type}__{generalized_target_entity}",
            }
        if kg_person.company not in implied_entities:
            # NOTE(review): if kg_person.name was already in implied_entities,
            # generalized_target_entity is unbound here and this branch would
            # raise NameError — confirm intended control flow.
            implied_entities = implied_entities | {f"VENDOR:{kg_person.company}"}
            implied_relationships = implied_relationships | {
                f"VENDOR:{kg_person.company}__{relationship_type}__{core_document_id_name}",
                f"VENDOR:{kg_person.company}__{relationship_type}__{generalized_target_entity}",
            }
    else:
        account_participant_emails = account_participant_emails | {
            f"{kg_person.name} -- ({kg_person.company})"
        }
        if kg_person.company not in implied_entities:
            implied_entities = implied_entities | {
                f"ACCOUNT:{kg_person.company}",
                "ACCOUNT:*",
            }
            implied_relationships = implied_relationships | {
                f"ACCOUNT:{kg_person.company}__{relationship_type}__{core_document_id_name}",
                f"ACCOUNT:*__{relationship_type}__{core_document_id_name}",
            }
            generalized_target_entity = list(
                generalize_entities([core_document_id_name])
            )[0]
            implied_relationships = implied_relationships | {
                f"ACCOUNT:*__{relationship_type}__{generalized_target_entity}",
                f"ACCOUNT:{kg_person.company}__{relationship_type}__{generalized_target_entity}",
            }

    return (
        implied_entities,
        implied_relationships,
        company_participant_emails,
        account_participant_emails,
    )
def prepare_llm_content_extraction(
    chunk: KGChunkFormat,
    company_participant_emails: set[str],
    account_participant_emails: set[str],
    kg_config_settings: KGConfigSettings,
) -> str:
    """Build the LLM preprocessing prompt for a chunk.

    Call-type chunks get the call-specific prompt (participants + vendor
    context); every other source type gets the general prompt.
    """
    call_source_types = {call_type.value.lower() for call_type in OnyxCallTypes}
    if chunk.source_type.lower() in call_source_types:
        # NOTE(review): the participant sets are interpolated with their raw
        # set repr (e.g. "{'Jane -- (Acme)'}") — confirm the prompt template
        # expects that rather than a joined, bulleted list.
        return CALL_CHUNK_PREPROCESSING_PROMPT.format(
            participant_string=company_participant_emails,
            account_participant_string=account_participant_emails,
            vendor=kg_config_settings.KG_VENDOR,
            content=chunk.content,
        )
    return GENERAL_CHUNK_PREPROCESSING_PROMPT.format(
        content=chunk.content,
        vendor=kg_config_settings.KG_VENDOR,
    )
def prepare_llm_document_content(
    document_classification_content: KGClassificationContent,
    category_list: str,
    category_definitions: dict[str, Dict[str, str | bool]],
    kg_config_settings: KGConfigSettings,
) -> KGDocumentClassificationPrompt:
    """
    Prepare the content for the extraction classification.

    Call-type documents get the call-specific classification prompt; every
    other source type yields an empty prompt (no classification performed).
    """
    # One "<category>: <description>" line per configured category.
    category_definition_string = "".join(
        f"{category}: {category_data['description']}\n"
        for category, category_data in category_definitions.items()
    )
    call_source_types = {call_type.value.lower() for call_type in OnyxCallTypes}
    if document_classification_content.source_type.lower() not in call_source_types:
        return KGDocumentClassificationPrompt(
            llm_prompt=None,
        )
    return _prepare_llm_document_content_call(
        document_classification_content,
        category_list,
        category_definition_string,
        kg_config_settings,
    )
def is_email(email: str) -> bool:
    """
    Check if a string is a valid email address.

    This is a light-weight sanity check (exactly one "@", a dot in the
    domain), not full RFC 5322 validation.
    """
    # fullmatch (rather than match) requires the whole string to conform,
    # so inputs with a second "@" after a valid-looking prefix (e.g.
    # "a@b.c@d") are rejected instead of matching just the prefix.
    return re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email) is not None

View File

@@ -0,0 +1,139 @@
from collections import defaultdict
from onyx.db.kg_config import KGConfigSettings
from onyx.kg.models import KGAggregatedExtractions
from onyx.kg.models import KGPerson
def format_entity(entity: str) -> str:
    """Normalize a "TYPE:name" entity string to upper-case type and
    title-case name.

    Strings that do not contain exactly one ":" separator are returned
    unchanged.
    """
    # Split once and reuse — the original computed the split twice.
    parts = entity.split(":")
    if len(parts) != 2:
        return entity
    entity_type, entity_name = parts
    return f"{entity_type.upper()}:{entity_name.title()}"
def format_relationship(relationship: str) -> str:
    """Format a "source__type__target" relationship string: both endpoint
    entities are normalized via format_entity and the relationship type is
    lower-cased."""
    source, rel_type, target = relationship.split("__")
    pieces = [format_entity(source), rel_type.lower(), format_entity(target)]
    return "__".join(pieces)
def format_relationship_type(relationship_type: str) -> str:
    """Normalize a "SOURCE__verb__TARGET" relationship-type string: the
    entity types are upper-cased and the verb is lower-cased."""
    # Unpack into fresh names rather than rebinding the parameter.
    source_type, verb, target_type = relationship_type.split("__")
    return "__".join((source_type.upper(), verb.lower(), target_type.upper()))
def generate_relationship_type(relationship: str) -> str:
    """Derive the relationship *type* from a concrete relationship by
    replacing each "TYPE:name" endpoint with just its upper-cased type."""
    source, verb, target = relationship.split("__")
    source_type = source.split(":")[0].upper()
    target_type = target.split(":")[0].upper()
    return f"{source_type}__{verb.lower()}__{target_type}"
def aggregate_kg_extractions(
    connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
) -> KGAggregatedExtractions:
    """Merge per-connector KG extraction aggregates into a single aggregate.

    Grounded-entity document ids are last-writer-wins across connectors;
    entity, relationship and term counts are summed.

    NOTE(review): the ``attributes`` field is initialized but never merged
    from the inputs — confirm whether attribute aggregation happens elsewhere
    or is missing here.
    """
    merged = KGAggregatedExtractions(
        grounded_entities_document_ids=defaultdict(str),
        entities=defaultdict(int),
        relationships=defaultdict(lambda: defaultdict(int)),
        terms=defaultdict(int),
        attributes=defaultdict(dict),
    )
    for extraction in connector_aggregated_kg_extractions_list:
        # Plain dict.update: same last-writer-wins behavior as assigning
        # each key in a loop.
        merged.grounded_entities_document_ids.update(
            extraction.grounded_entities_document_ids
        )
        # defaultdict(int) makes the "first time seen" case identical to +=,
        # so no membership checks are needed.
        for entity, count in extraction.entities.items():
            merged.entities[entity] += count
        for relationship, per_doc_counts in extraction.relationships.items():
            for source_document_id, count in per_doc_counts.items():
                merged.relationships[relationship][source_document_id] += count
        for term, count in extraction.terms.items():
            merged.terms[term] += count
    return merged
def kg_email_processing(email: str, kg_config_settings: KGConfigSettings) -> KGPerson:
    """
    Split an email address into a KGPerson: the local part becomes the name,
    and the domain determines the company — the configured vendor name when
    the domain matches a vendor domain, otherwise the capitalized domain.
    """
    name, company_domain = email.split("@")
    assert isinstance(company_domain, str)
    assert isinstance(kg_config_settings.KG_VENDOR_DOMAINS, list)
    assert isinstance(kg_config_settings.KG_VENDOR, str)
    is_employee = any(
        vendor_domain in company_domain
        for vendor_domain in kg_config_settings.KG_VENDOR_DOMAINS
    )
    company = (
        kg_config_settings.KG_VENDOR if is_employee else company_domain.capitalize()
    )
    return KGPerson(name=name, company=company, employee=is_employee)
def generalize_entities(entities: list[str]) -> set[str]:
    """
    Generalize entities to their superclass: "TYPE:name" -> "TYPE:*".

    Strings without a ":" are treated as a bare type and still get ":*"
    appended.
    """
    # Set comprehension instead of set([...]) (flake8-comprehensions C403);
    # split(..., 1) since only the first segment is used.
    return {f"{entity.split(':', 1)[0]}:*" for entity in entities}
def generalize_relationships(relationships: list[str]) -> set[str]:
    """
    Generalize relationships to their superclass.

    For each "source__type__target" relationship, emit the three variants in
    which the source, the target, or both endpoints are replaced by their
    generalized ("TYPE:*") form.
    """
    generalized: set[str] = set()
    for relationship in relationships:
        parts = relationship.split("__")
        assert len(parts) == 3, "Relationship is not in the correct format"
        source_entity, relationship_type, target_entity = parts
        general_source = next(iter(generalize_entities([source_entity])))
        general_target = next(iter(generalize_entities([target_entity])))
        generalized.update(
            {
                f"{general_source}__{relationship_type}__{target_entity}",
                f"{source_entity}__{relationship_type}__{general_target}",
                f"{general_source}__{relationship_type}__{general_target}",
            }
        )
    return generalized

View File

@@ -0,0 +1,244 @@
import json
from collections.abc import Generator
from onyx.configs.constants import OnyxCallTypes
from onyx.db.kg_config import KGConfigSettings
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
from onyx.document_index.vespa.index import IndexFilters
from onyx.kg.models import KGChunkFormat
from onyx.kg.models import KGClassificationContent
from onyx.kg.utils.formatting_utils import kg_email_processing
def get_document_classification_content_for_kg_processing(
    document_ids: list[str],
    source: str,
    index_name: str,
    kg_config_settings: KGConfigSettings,
    batch_size: int = 8,
    num_classification_chunks: int = 3,
    entity_type: str | None = None,
) -> Generator[list[KGClassificationContent], None, None]:
    """
    Generates the content used for initial classification of a document from
    the first num_classification_chunks chunks.

    Yields one list of KGClassificationContent per batch of up to
    ``batch_size`` document ids. Documents for which no chunks are found in
    the index are skipped.
    """
    classification_content_list: list[KGClassificationContent] = []
    for i in range(0, len(document_ids), batch_size):
        batch_document_ids = document_ids[i : i + batch_size]
        for document_id in batch_document_ids:
            first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
                chunk_request=VespaChunkRequest(
                    document_id=document_id,
                    max_chunk_ind=num_classification_chunks - 1,
                    min_chunk_ind=0,
                ),
                index_name=index_name,
                filters=IndexFilters(access_control_list=None),
                field_names=[
                    "document_id",
                    "chunk_id",
                    "title",
                    "content",
                    "metadata",
                    "source_type",
                    "primary_owners",
                    "secondary_owners",
                ],
                get_large_chunks=False,
            )
            # Skip documents with no retrievable chunks; indexing [0] below
            # would otherwise raise IndexError.
            if not first_num_classification_chunks:
                continue
            first_num_classification_chunks = sorted(
                first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
            )[:num_classification_chunks]
            classification_content = _get_classification_content_from_chunks(
                first_num_classification_chunks,
                kg_config_settings,
            )
            # metadata may arrive JSON-encoded; normalize to a dict (or None)
            metadata = first_num_classification_chunks[0]["fields"]["metadata"]
            if isinstance(metadata, str):
                metadata = json.loads(metadata)
            assert isinstance(metadata, dict) or metadata is None
            classification_content_list.append(
                KGClassificationContent(
                    document_id=document_id,
                    classification_content=classification_content,
                    source_type=first_num_classification_chunks[0]["fields"][
                        "source_type"
                    ],
                    source_metadata=metadata,
                    entity_type=entity_type,
                )
            )
        # Yield the batch of classification content. The list is always
        # reset here, so no trailing flush after the loop is needed.
        if classification_content_list:
            yield classification_content_list
            classification_content_list = []
def get_document_chunks_for_kg_processing(
    document_id: str,
    deep_extraction: bool,
    index_name: str,
    batch_size: int = 8,
) -> Generator[list[KGChunkFormat], None, None]:
    """
    Retrieves all chunks from Vespa for the given document and converts them
    to KGChunkFormat objects.

    Args:
        document_id: ID of the document to fetch chunks for
        deep_extraction: forwarded onto every produced KGChunkFormat
        index_name: Name of the Vespa index
        batch_size: maximum number of chunks per yielded list
    Yields:
        list[KGChunkFormat]: Batches of chunks ready for KG processing
    """
    current_batch: list[KGChunkFormat] = []
    # get all chunks for the document
    chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=IndexFilters(access_control_list=None),
        field_names=[
            "document_id",
            "chunk_id",
            "title",
            "content",
            "metadata",
            "primary_owners",
            "secondary_owners",
            "source_type",
            "kg_entities",
            "kg_relationships",
            "kg_terms",
        ],
        get_large_chunks=False,
    )
    # Convert Vespa chunks to KGChunks
    for chunk in chunks:
        fields = chunk["fields"]
        # metadata may be stored JSON-encoded; normalize to a dict
        if isinstance(fields.get("metadata", {}), str):
            fields["metadata"] = json.loads(fields["metadata"])
        current_batch.append(
            KGChunkFormat(
                connector_id=None,  # We may need to adjust this
                document_id=fields.get("document_id"),
                chunk_id=fields.get("chunk_id"),
                primary_owners=fields.get("primary_owners", []),
                secondary_owners=fields.get("secondary_owners", []),
                source_type=fields.get("source_type", ""),
                title=fields.get("title", ""),
                content=fields.get("content", ""),
                metadata=fields.get("metadata", {}),
                entities=fields.get("kg_entities", {}),
                relationships=fields.get("kg_relationships", {}),
                terms=fields.get("kg_terms", {}),
                deep_extraction=deep_extraction,
            )
        )
        if len(current_batch) >= batch_size:
            yield current_batch
            current_batch = []
    # Yield any remaining chunks
    if current_batch:
        yield current_batch
def _get_classification_content_from_call_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Build the classification content string for a call-type document: the
    title, the vendor/other participants (derived from the chunk owners),
    and the beginning of the call transcript.
    """
    assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)
    primary_owners = first_num_classification_chunks[0]["fields"].get(
        "primary_owners", []
    )
    secondary_owners = first_num_classification_chunks[0]["fields"].get(
        "secondary_owners", []
    )
    company_participant_emails = set()
    account_participant_emails = set()
    # `or []` guards against the fields being present but null in Vespa,
    # in which case .get() returns None and concatenation would raise.
    for owner in (primary_owners or []) + (secondary_owners or []):
        kg_owner = kg_email_processing(owner, kg_config_settings)
        # Skip owners whose company matches an ignored email domain
        if any(
            domain.lower() in kg_owner.company.lower()
            for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
        ):
            continue
        if kg_owner.employee:
            company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
        else:
            account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
    # Sort for a deterministic prompt — set iteration order is arbitrary.
    participant_string = "\n - " + "\n - ".join(sorted(company_participant_emails))
    account_participant_string = "\n - " + "\n - ".join(
        sorted(account_participant_emails)
    )
    title_string = first_num_classification_chunks[0]["fields"]["title"]
    content_string = "\n".join(
        [
            chunk_content["fields"]["content"]
            for chunk_content in first_num_classification_chunks
        ]
    )
    classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
    return classification_content
def _get_classification_content_from_chunks(
    first_num_classification_chunks: list[dict],
    kg_config_settings: KGConfigSettings,
) -> str:
    """
    Creates the classification content string from a list of chunks,
    dispatching to the call-specific builder for call-type documents.
    """
    source_type = first_num_classification_chunks[0]["fields"]["source_type"]
    call_source_types = {call_type.value.lower() for call_type in OnyxCallTypes}
    if source_type.lower() in call_source_types:
        return _get_classification_content_from_call_chunks(
            first_num_classification_chunks,
            kg_config_settings,
        )
    title = first_num_classification_chunks[0]["fields"]["title"]
    contents = "\n".join(
        chunk["fields"]["content"] for chunk in first_num_classification_chunks
    )
    return title + "\n" + contents

View File

@@ -40,6 +40,8 @@ from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE
from onyx.configs.app_configs import SYSTEM_RECURSION_LIMIT
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
@@ -208,6 +210,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
)
SqlEngine.get_engine()
SqlEngine.init_readonly_engine(
pool_size=POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)

File diff suppressed because it is too large Load Diff

View File

@@ -81,6 +81,11 @@ class SearchToolOverrideKwargs(BaseModel):
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
expanded_queries: QueryExpansions | None = None
kg_entities: list[str] | None = None
kg_relationships: list[str] | None = None
kg_terms: list[str] | None = None
kg_sources: list[str] | None = None
kg_chunk_id_zero_only: bool | None = False
class Config:
arbitrary_types_allowed = True

View File

@@ -299,6 +299,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_sources = None
time_cutoff = None
expanded_queries = None
kg_entities = None
kg_relationships = None
kg_terms = None
kg_sources = None
kg_chunk_id_zero_only = False
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
@@ -312,7 +317,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
expanded_queries = override_kwargs.expanded_queries
kg_entities = override_kwargs.kg_entities
kg_relationships = override_kwargs.kg_relationships
kg_terms = override_kwargs.kg_terms
kg_sources = override_kwargs.kg_sources
kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False
# Fast path for ordering-only search
if ordering_only:
yield from self._run_ordering_only_search(
@@ -361,6 +370,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Overwrite time-cutoff should supercede existing time-cutoff, even if defined
retrieval_options.filters.time_cutoff = time_cutoff
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
if kg_entities:
retrieval_options.filters.kg_entities = kg_entities
if kg_relationships:
retrieval_options.filters.kg_relationships = kg_relationships
if kg_terms:
retrieval_options.filters.kg_terms = kg_terms
if kg_sources:
retrieval_options.filters.kg_sources = kg_sources
if kg_chunk_id_zero_only:
retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,

View File

@@ -80,6 +80,7 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.46.1
supervisor==4.2.5
thefuzz==0.22.1
tiktoken==0.7.0
timeago==1.0.16
transformers==4.49.0

View File

@@ -155,7 +155,7 @@
"Email: fiannellib46@marriott.com\nIsDeleted: false\nLastName: Iannelli\nIsEmailBounced: false\nFirstName: Felicio\nIsPriorityRecord: false\nCleanStatus: Pending"
],
"semantic_identifier": "Voonder",
"metadata": {},
"metadata": {"object_type": "Account"},
"primary_owners": {"email": "hagen@danswer.ai", "first_name": "Hagen", "last_name": "oneill"},
"secondary_owners": null,
"title": null

View File

@@ -51,15 +51,21 @@ def answer_instance(
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config, mocker)
def _answer_fixture_impl(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
rerank_settings: RerankingDetails | None = None,
) -> Answer:
mock_db_session = Mock(spec=Session)
mock_query = Mock()
mock_query.all.return_value = []
mock_db_session.query.return_value = mock_query
return Answer(
prompt_builder=AnswerPromptBuilder(
user_message=default_build_user_message(
@@ -74,7 +80,7 @@ def _answer_fixture_impl(
raw_user_query=QUERY,
raw_user_uploaded_files=[],
),
db_session=Mock(spec=Session),
db_session=mock_db_session,
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,
@@ -404,7 +410,11 @@ def test_no_slow_reranking(
)
)
answer_instance = _answer_fixture_impl(
mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
mock_llm,
answer_style_config,
prompt_config,
mocker,
rerank_settings=rerank_settings,
)
assert (

View File

@@ -43,6 +43,7 @@ def test_skip_gen_ai_answer_generation_flag(
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
@@ -59,8 +60,14 @@ def test_skip_gen_ai_answer_generation_flag(
mock_llm.stream = Mock()
mock_llm.stream.return_value = [Mock()]
# Set up the mock database session
mock_db_session = Mock(spec=Session)
mock_query = Mock()
mock_db_session.query.return_value = mock_query
mock_query.all.return_value = [] # Return empty list for KGConfig query
answer = Answer(
db_session=Mock(spec=Session),
db_session=mock_db_session,
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,

View File

@@ -747,11 +747,17 @@ def test_salesforce_sqlite() -> None:
sf_db.apply_schema()
_create_csv_with_example_data(sf_db)
_test_query(sf_db)
_test_upsert(sf_db)
_test_relationships(sf_db)
_test_account_with_children(sf_db)
_test_relationship_updates(sf_db)
_test_get_affected_parent_ids(sf_db)
sf_db.close()

View File

@@ -183,6 +183,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
@@ -383,6 +385,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -148,6 +148,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- VESPA_HOST=index
- REDIS_HOST=cache
@@ -329,6 +331,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -166,6 +166,8 @@ services:
- POSTGRES_HOST=relational_db
- POSTGRES_USER=${POSTGRES_USER:-}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- VESPA_HOST=index
@@ -356,6 +358,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432:5432"
volumes:

View File

@@ -153,6 +153,8 @@ services:
environment:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
- DB_READONLY_USER=${DB_READONLY_USER:-}
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
ports:
- "5432"
volumes:

View File

@@ -55,3 +55,10 @@ SESSION_EXPIRE_TIME_SECONDS=604800
# Default values here are what Postgres uses by default, feel free to change.
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password
# Default values here are for the read-only user used by the knowledge graph and other future read-only purposes.
# Please change the password!
DB_READONLY_USER=db_readonly_user
DB_READONLY_PASSWORD=password

View File

@@ -40,6 +40,16 @@ spec:
secretKeyRef:
name: onyx-secrets
key: postgres_password
- name: DB_READONLY_USER
valueFrom:
secretKeyRef:
name: onyx-secrets
key: db_readonly_user
- name: DB_READONLY_PASSWORD
valueFrom:
secretKeyRef:
name: onyx-secrets
key: db_readonly_password
args: ["-c", "max_connections=250"]
ports:
- containerPort: 5432

View File

@@ -532,7 +532,7 @@ export function AssistantEditor({
// if disable_retrieval is set, set num_chunks to 0
// to tell the backend to not fetch any documents
const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0;
const numChunks = searchToolEnabled ? values.num_chunks || 25 : 0;
const starterMessages = values.starter_messages
.filter(
(message: { message: string }) => message.message.trim() !== ""