mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 01:22:45 +00:00
Compare commits
67 Commits
v3.0.4
...
KG_dev_1_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92bb851e19 | ||
|
|
3951cf2f78 | ||
|
|
9949ffd9bc | ||
|
|
5574e7485f | ||
|
|
99cf8dca74 | ||
|
|
aed1d87d89 | ||
|
|
6bb9b17c1c | ||
|
|
aaed5716f9 | ||
|
|
e302bcafaf | ||
|
|
57d412c3a7 | ||
|
|
6fb0933cfb | ||
|
|
e265b4f747 | ||
|
|
6c49736024 | ||
|
|
1779b65185 | ||
|
|
148bce59d9 | ||
|
|
8cadb57df2 | ||
|
|
33a4b15ae2 | ||
|
|
49e1a4b782 | ||
|
|
1f95e2291d | ||
|
|
64cba0bdca | ||
|
|
daad2e9de0 | ||
|
|
69d8aae5f0 | ||
|
|
34630b1947 | ||
|
|
e65e192622 | ||
|
|
713cc6b6a4 | ||
|
|
84de01d726 | ||
|
|
df2bc953e8 | ||
|
|
cd4fd267e1 | ||
|
|
078fe358bb | ||
|
|
2b9a8baf7a | ||
|
|
dbed47b4b0 | ||
|
|
ffc0692f5f | ||
|
|
0c5db00673 | ||
|
|
b4641189a0 | ||
|
|
c01136d661 | ||
|
|
034224a946 | ||
|
|
a03c2ed4ff | ||
|
|
ad966c51b4 | ||
|
|
d68bd41e72 | ||
|
|
8e2417f563 | ||
|
|
498fc587ab | ||
|
|
470a4b88ae | ||
|
|
0906125af0 | ||
|
|
38733965ba | ||
|
|
acf8a57798 | ||
|
|
462592867c | ||
|
|
0980aa4222 | ||
|
|
a8b748066e | ||
|
|
81833c8a54 | ||
|
|
cd1b48acd4 | ||
|
|
1552c61dbd | ||
|
|
5915066558 | ||
|
|
3aa2b51ca4 | ||
|
|
7bfb1a4a61 | ||
|
|
dfa12e9caf | ||
|
|
c595df7a1e | ||
|
|
1b08b35262 | ||
|
|
fe919d54da | ||
|
|
8259773174 | ||
|
|
b1406d6b65 | ||
|
|
95f5769fb6 | ||
|
|
8ed105bc05 | ||
|
|
b818842a94 | ||
|
|
b72ec48b41 | ||
|
|
45862c941b | ||
|
|
b04925816c | ||
|
|
c14468413a |
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -114,6 +114,8 @@ jobs:
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
@@ -212,6 +214,8 @@ jobs:
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
|
||||
@@ -152,6 +152,8 @@ jobs:
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
|
||||
@@ -0,0 +1,438 @@
|
||||
"""create knowlege graph tables
|
||||
|
||||
Revision ID: 495cb26ce93e
|
||||
Revises: a7688ab35c45
|
||||
Create Date: 2025-03-19 08:51:14.341989
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "495cb26ce93e"
|
||||
down_revision = "a7688ab35c45"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
# Create a new permission-less user to be later used for knowledge graph queries.
|
||||
# The user will later get temporary read priviledges for a specific view that will be
|
||||
# ad hoc generated specific to a knowledge graph query.
|
||||
#
|
||||
# Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD
|
||||
# environment variables MUST be set. Otherwise, an exception will be raised.
|
||||
|
||||
if not MULTI_TENANT:
|
||||
# Create read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is created in the alembic_tenants migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- Explicitly revoke all privileges including CONNECT
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"kg_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("kg_variable_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"kg_entity_type",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("grounding", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"attributes",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column("grounded_source_name", sa.String(), nullable=True),
|
||||
sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
)
|
||||
|
||||
# Create KGRelationshipType table
|
||||
op.create_table(
|
||||
"kg_relationship_type",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column(
|
||||
"source_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"target_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
)
|
||||
|
||||
# Create KGRelationshipTypeExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_relationship_type_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column(
|
||||
"source_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column(
|
||||
"target_entity_type_id_name", sa.String(), nullable=False, index=True
|
||||
),
|
||||
sa.Column("definition", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=False, default=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.Column(
|
||||
"clustering",
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_entity_type_id_name"], ["kg_entity_type.id_name"]
|
||||
),
|
||||
)
|
||||
|
||||
# Create KGEntity table
|
||||
op.create_table(
|
||||
"kg_entity",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("sub_type", sa.String(), nullable=True, index=True),
|
||||
sa.Column("document_id", sa.String(), nullable=True, index=True),
|
||||
sa.Column(
|
||||
"alternative_names",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"keywords",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
|
||||
),
|
||||
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
|
||||
)
|
||||
op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"])
|
||||
op.create_index(
|
||||
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
|
||||
)
|
||||
|
||||
# Create KGEntityExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_entity_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column("name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("sub_type", sa.String(), nullable=True, index=True),
|
||||
sa.Column("document_id", sa.String(), nullable=True, index=True),
|
||||
sa.Column(
|
||||
"alternative_names",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"keywords",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}"
|
||||
),
|
||||
sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"),
|
||||
sa.Column("event_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["document_id"], ["document.id"]),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_entity_extraction_staging_acl",
|
||||
"kg_entity_extraction_staging",
|
||||
["entity_type_id_name", "acl"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_entity_extraction_staging_name_search",
|
||||
"kg_entity_extraction_staging",
|
||||
["name", "entity_type_id_name"],
|
||||
)
|
||||
|
||||
# Create KGRelationship table
|
||||
op.create_table(
|
||||
"kg_relationship",
|
||||
sa.Column("id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_document", sa.String(), nullable=True, index=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
|
||||
sa.ForeignKeyConstraint(
|
||||
["relationship_type_id_name"], ["kg_relationship_type.id_name"]
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id_name", "source_document"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
|
||||
)
|
||||
|
||||
# Create KGRelationshipExtractionStaging table
|
||||
op.create_table(
|
||||
"kg_relationship_extraction_staging",
|
||||
sa.Column("id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("target_node_type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("source_document", sa.String(), nullable=True, index=True),
|
||||
sa.Column("type", sa.String(), nullable=False, index=True),
|
||||
sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True),
|
||||
sa.Column("occurrences", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_node"], ["kg_entity_extraction_staging.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["target_node"], ["kg_entity_extraction_staging.id_name"]
|
||||
),
|
||||
sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]),
|
||||
sa.ForeignKeyConstraint(["source_document"], ["document.id"]),
|
||||
sa.ForeignKeyConstraint(
|
||||
["relationship_type_id_name"],
|
||||
["kg_relationship_type_extraction_staging.id_name"],
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_extraction_staging_source_target_type",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id_name", "source_document"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_kg_relationship_extraction_staging_nodes",
|
||||
"kg_relationship_extraction_staging",
|
||||
["source_node", "target_node"],
|
||||
)
|
||||
|
||||
# Create KGTerm table
|
||||
op.create_table(
|
||||
"kg_term",
|
||||
sa.Column("id_term", sa.String(), primary_key=True, nullable=False, index=True),
|
||||
sa.Column(
|
||||
"entity_types",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=False,
|
||||
server_default="{}",
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
),
|
||||
sa.Column(
|
||||
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()")
|
||||
),
|
||||
)
|
||||
op.create_index("ix_search_term_entities", "kg_term", ["entity_types"])
|
||||
op.create_index("ix_search_term_term", "kg_term", ["id_term"])
|
||||
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("kg_stage", sa.String(), nullable=True, index=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"connector",
|
||||
sa.Column(
|
||||
"kg_processing_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop tables in reverse order of creation to handle dependencies
|
||||
op.drop_table("kg_term")
|
||||
op.drop_table("kg_relationship")
|
||||
op.drop_table("kg_entity")
|
||||
op.drop_table("kg_relationship_type")
|
||||
op.drop_table("kg_relationship_extraction_staging")
|
||||
op.drop_table("kg_relationship_type_extraction_staging")
|
||||
op.drop_table("kg_entity_extraction_staging")
|
||||
op.drop_table("kg_entity_type")
|
||||
op.drop_column("connector", "kg_processing_enabled")
|
||||
op.drop_column("document", "kg_stage")
|
||||
op.drop_column("document", "kg_processing_time")
|
||||
op.drop_table("kg_config")
|
||||
|
||||
if not MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
"""add_db_readonly_user
|
||||
|
||||
Revision ID: 3b9f09038764
|
||||
Revises: 3b45e0018bf1
|
||||
Create Date: 2025-05-11 11:05:11.436977
|
||||
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b9f09038764"
|
||||
down_revision = "3b45e0018bf1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- Explicitly revoke all privileges including CONNECT
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -35,9 +35,8 @@ def research_object_source(
|
||||
datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.search_request.query
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.search_request.query
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
object, document_source = state.object_source_combination
|
||||
|
||||
if search_tool is None or graph_config.inputs.search_request.persona is None:
|
||||
@@ -153,7 +152,6 @@ def research_object_source(
|
||||
),
|
||||
)
|
||||
]
|
||||
# fast_llm = graph_config.tooling.fast_llm
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Grader
|
||||
|
||||
@@ -73,7 +73,6 @@ def consolidate_object_research(
|
||||
)
|
||||
]
|
||||
graph_config.tooling.primary_llm
|
||||
# fast_llm = graph_config.tooling.fast_llm
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Grader
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
|
||||
|
||||
|
||||
class KGAnalysisPath(str, Enum):
|
||||
PROCESS_KG_ONLY_ANSWERS = "process_kg_only_answers"
|
||||
CONSTRUCT_DEEP_SEARCH_FILTERS = "construct_deep_search_filters"
|
||||
|
||||
|
||||
def simple_vs_search(
|
||||
state: MainState,
|
||||
) -> str:
|
||||
|
||||
if (
|
||||
state.strategy == KGAnswerStrategy.DEEP
|
||||
or state.search_type == KGSearchType.SEARCH
|
||||
):
|
||||
return KGAnalysisPath.CONSTRUCT_DEEP_SEARCH_FILTERS.value
|
||||
else:
|
||||
return KGAnalysisPath.PROCESS_KG_ONLY_ANSWERS.value
|
||||
|
||||
|
||||
def research_individual_object(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable] | str:
|
||||
edge_start_time = datetime.now()
|
||||
|
||||
assert state.div_con_entities is not None
|
||||
assert state.broken_down_question is not None
|
||||
assert state.vespa_filter_results is not None
|
||||
|
||||
if (
|
||||
state.search_type == KGSearchType.SQL
|
||||
and state.strategy == KGAnswerStrategy.DEEP
|
||||
):
|
||||
|
||||
return [
|
||||
Send(
|
||||
"process_individual_deep_search",
|
||||
ResearchObjectInput(
|
||||
research_nr=research_nr + 1,
|
||||
entity=entity,
|
||||
broken_down_question=state.broken_down_question,
|
||||
vespa_filter_results=state.vespa_filter_results,
|
||||
source_division=state.source_division,
|
||||
source_entity_filters=state.source_filters,
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
step_results=[],
|
||||
),
|
||||
)
|
||||
for research_nr, entity in enumerate(state.div_con_entities)
|
||||
]
|
||||
elif state.search_type == KGSearchType.SEARCH:
|
||||
return "filtered_search"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid combination of search type: {state.search_type} and strategy: {state.strategy}"
|
||||
)
|
||||
132
backend/onyx/agents/agent_search/kb_search/graph_builder.py
Normal file
132
backend/onyx/agents/agent_search/kb_search/graph_builder.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.kb_search.conditional_edges import (
|
||||
research_individual_object,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.conditional_edges import simple_vs_search
|
||||
from onyx.agents.agent_search.kb_search.nodes.a1_extract_ert import extract_ert
|
||||
from onyx.agents.agent_search.kb_search.nodes.a2_analyze import analyze
|
||||
from onyx.agents.agent_search.kb_search.nodes.a3_generate_simple_sql import (
|
||||
generate_simple_sql,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.nodes.b1_construct_deep_search_filters import (
|
||||
construct_deep_search_filters,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.nodes.b2p_process_individual_deep_search import (
|
||||
process_individual_deep_search,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.nodes.b2s_filtered_search import filtered_search
|
||||
from onyx.agents.agent_search.kb_search.nodes.b3_consolidate_individual_deep_search import (
|
||||
consolidate_individual_deep_search,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.nodes.c1_process_kg_only_answers import (
|
||||
process_kg_only_answers,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.nodes.d1_generate_answer import generate_answer
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kb_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the knowledge graph search process.
|
||||
"""
|
||||
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
"extract_ert",
|
||||
extract_ert,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"generate_simple_sql",
|
||||
generate_simple_sql,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"filtered_search",
|
||||
filtered_search,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"analyze",
|
||||
analyze,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"generate_answer",
|
||||
generate_answer,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"construct_deep_search_filters",
|
||||
construct_deep_search_filters,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"process_individual_deep_search",
|
||||
process_individual_deep_search,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"consoldidate_individual_deep_search",
|
||||
consolidate_individual_deep_search,
|
||||
)
|
||||
|
||||
graph.add_node("process_kg_only_answers", process_kg_only_answers)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="extract_ert")
|
||||
|
||||
graph.add_edge(
|
||||
start_key="extract_ert",
|
||||
end_key="analyze",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="analyze",
|
||||
end_key="generate_simple_sql",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges("generate_simple_sql", simple_vs_search)
|
||||
|
||||
graph.add_edge(start_key="process_kg_only_answers", end_key="generate_answer")
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="construct_deep_search_filters",
|
||||
path=research_individual_object,
|
||||
path_map=["process_individual_deep_search", "filtered_search"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="process_individual_deep_search",
|
||||
end_key="consoldidate_individual_deep_search",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="consoldidate_individual_deep_search", end_key="generate_answer"
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="filtered_search",
|
||||
end_key="generate_answer",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="generate_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
461
backend/onyx/agents/agent_search/kb_search/graph_utils.py
Normal file
461
backend/onyx/agents/agent_search/kb_search/graph_utils.py
Normal file
@@ -0,0 +1,461 @@
|
||||
import re
|
||||
from time import sleep
|
||||
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
|
||||
from onyx.agents.agent_search.kb_search.models import KGExpandedGraphObjects
|
||||
from onyx.agents.agent_search.kb_search.states import SubQuestionAnswerResults
|
||||
from onyx.agents.agent_search.kb_search.step_definitions import STEP_DESCRIPTIONS
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.document import get_kg_doc_info_for_entity_name
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_document_id_for_entity
|
||||
from onyx.db.entity_type import get_entity_types
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _check_entities_disconnected(
|
||||
current_entities: list[str], current_relationships: list[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if all entities in current_entities are disconnected via the given relationships.
|
||||
Relationships are in the format: source_entity__relationship_name__target_entity
|
||||
|
||||
Args:
|
||||
current_entities: List of entity IDs to check connectivity for
|
||||
current_relationships: List of relationships in format source__relationship__target
|
||||
|
||||
Returns:
|
||||
bool: True if all entities are disconnected, False otherwise
|
||||
"""
|
||||
if not current_entities:
|
||||
return True
|
||||
|
||||
# Create a graph representation using adjacency list
|
||||
graph: dict[str, set[str]] = {entity: set() for entity in current_entities}
|
||||
|
||||
# Build the graph from relationships
|
||||
for relationship in current_relationships:
|
||||
try:
|
||||
source, _, target = relationship.split("__")
|
||||
if source in graph and target in graph:
|
||||
graph[source].add(target)
|
||||
# Add reverse edge to capture that we do also have a relationship in the other direction,
|
||||
# albeit not quite the same one.
|
||||
graph[target].add(source)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid relationship format: {relationship}")
|
||||
|
||||
# Use BFS to check if all entities are connected
|
||||
visited: set[str] = set()
|
||||
start_entity = current_entities[0]
|
||||
|
||||
def _bfs(start: str) -> None:
|
||||
queue = [start]
|
||||
visited.add(start)
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for neighbor in graph[current]:
|
||||
if neighbor not in visited:
|
||||
visited.add(neighbor)
|
||||
queue.append(neighbor)
|
||||
|
||||
# Start BFS from the first entity
|
||||
_bfs(start_entity)
|
||||
|
||||
logger.debug(f"Number of visited entities: {len(visited)}")
|
||||
|
||||
# Check if all current_entities are in visited
|
||||
return not all(entity in visited for entity in current_entities)
|
||||
|
||||
|
||||
def create_minimal_connected_query_graph(
|
||||
entities: list[str], relationships: list[str], max_depth: int = 1
|
||||
) -> KGExpandedGraphObjects:
|
||||
"""
|
||||
TODO: Implement this. For now we'll trust the SQL generation to do the right thing.
|
||||
Return the original entities and relationships.
|
||||
"""
|
||||
return KGExpandedGraphObjects(entities=entities, relationships=relationships)
|
||||
|
||||
|
||||
# def rename_entities_in_answer(answer: str) -> str:
|
||||
# """
|
||||
# Rename entities in the answer to be more readable by replacing entity references
|
||||
# with their semantic_id and link. This is case-insensitive and handles spaces between
|
||||
# entity type and ID. Trailing quotes are removed from entity names.
|
||||
# """
|
||||
# # Create a mapping of entity IDs to new names
|
||||
# entity_mapping = {}
|
||||
|
||||
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# # Get all entity types
|
||||
# entity_types = get_entity_types(db_session)
|
||||
|
||||
# # For each entity type, find all entities in the answer
|
||||
# for entity_type in entity_types:
|
||||
# # Find all occurrences of <entity_type>:<entity_name> in the answer (case-insensitive)
|
||||
# # Pattern now handles spaces after the colon
|
||||
# pattern = f"{entity_type.id_name}:\\s*([^\\s,;.]+)"
|
||||
# matches = re.finditer(pattern, answer, re.IGNORECASE)
|
||||
|
||||
# for match in matches:
|
||||
# # Get the full match including any spaces
|
||||
# full_match = match.group(0)
|
||||
# # Get just the entity ID part (without spaces) and remove trailing quotes
|
||||
# entity_name = match.group(1).rstrip("\"'")
|
||||
# entity_id = f"{entity_type.id_name}:{entity_name}"
|
||||
|
||||
# if entity_id.lower() in entity_mapping:
|
||||
# continue
|
||||
|
||||
# # Get the document for this entity
|
||||
# entity = (
|
||||
# db_session.query(KGEntity)
|
||||
# .filter(
|
||||
# KGEntity.id_name.ilike(
|
||||
# entity_id
|
||||
# ) # Case-insensitive comparison
|
||||
# )
|
||||
# .first()
|
||||
# )
|
||||
|
||||
# if entity and entity.document_id:
|
||||
# # Get the document's semantic_id and link
|
||||
# document = (
|
||||
# db_session.query(Document)
|
||||
# .filter(Document.id == entity.document_id)
|
||||
# .first()
|
||||
# )
|
||||
|
||||
# if document:
|
||||
# # Create the replacement text with semantic_id and link
|
||||
# replacement = f"{document.semantic_id}"
|
||||
# if document.link:
|
||||
# replacement = f"[{replacement}]({document.link})"
|
||||
# entity_mapping[entity_id.lower()] = replacement
|
||||
# # Also map the full match (with spaces) to the same replacement
|
||||
# entity_mapping[full_match.lower()] = replacement
|
||||
|
||||
# # Replace all entity references in the answer (case-insensitive)
|
||||
# for entity_id, replacement in entity_mapping.items():
|
||||
# # Use regex for case-insensitive replacement
|
||||
# answer = re.sub(re.escape(entity_id), replacement, answer, flags=re.IGNORECASE)
|
||||
|
||||
# return answer
|
||||
|
||||
|
||||
def stream_write_step_description(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=STEP_DESCRIPTIONS[step_nr].description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# Give the frontend a brief moment to catch up
|
||||
sleep(0.2)
|
||||
|
||||
|
||||
def stream_write_step_activities(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
for activity_nr, activity in enumerate(STEP_DESCRIPTIONS[step_nr].activities):
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=activity,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
query_id=activity_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_activity_explicit(
|
||||
writer: StreamWriter, step_nr: int, query_id: int, activity: str, level: int = 0
|
||||
) -> None:
|
||||
for activity in STEP_DESCRIPTIONS[step_nr].activities:
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=activity,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
query_id=query_id,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_answer_explicit(
|
||||
writer: StreamWriter, step_nr: int, answer: str, level: int = 0
|
||||
) -> None:
|
||||
write_custom_event(
|
||||
"sub_answers",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=answer,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
answer_type="agent_sub_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def stream_write_step_structure(writer: StreamWriter, level: int = 0) -> None:
|
||||
for step_nr, step_detail in STEP_DESCRIPTIONS.items():
|
||||
|
||||
write_custom_event(
|
||||
"decomp_qs",
|
||||
SubQuestionPiece(
|
||||
sub_question=step_detail.description,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
for step_nr in STEP_DESCRIPTIONS.keys():
|
||||
|
||||
write_custom_event(
|
||||
"stream_finished",
|
||||
StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=0,
|
||||
)
|
||||
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_close_step_answer(
|
||||
writer: StreamWriter, step_nr: int, level: int = 0
|
||||
) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_ANSWER,
|
||||
level=level,
|
||||
level_question_num=step_nr,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_write_close_steps(writer: StreamWriter, level: int = 0) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.SUB_QUESTIONS,
|
||||
level=level,
|
||||
)
|
||||
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_write_close_main_answer(writer: StreamWriter, level: int = 0) -> None:
|
||||
stop_event = StreamStopInfo(
|
||||
stop_reason=StreamStopReason.FINISHED,
|
||||
stream_type=StreamType.MAIN_ANSWER,
|
||||
level=level,
|
||||
level_question_num=0,
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
|
||||
def stream_write_main_answer_token(
|
||||
writer: StreamWriter, token: str, level: int = 0, level_question_num: int = 0
|
||||
) -> None:
|
||||
write_custom_event(
|
||||
"initial_agent_answer",
|
||||
AgentAnswerPiece(
|
||||
answer_piece=token, # No need to add space as tokenizer handles this
|
||||
level=level,
|
||||
level_question_num=level_question_num,
|
||||
answer_type="agent_level_answer",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
|
||||
def get_doc_information_for_entity(entity_id_name: str) -> KGEntityDocInfo:
|
||||
"""
|
||||
Get document information for an entity, including its semantic name and document details.
|
||||
"""
|
||||
if ":" not in entity_id_name:
|
||||
return KGEntityDocInfo(
|
||||
doc_id=None,
|
||||
doc_semantic_id=None,
|
||||
doc_link=None,
|
||||
semantic_entity_name=entity_id_name,
|
||||
semantic_linked_entity_name=entity_id_name,
|
||||
)
|
||||
|
||||
entity_type, entity_name = map(str.strip, entity_id_name.split(":", 1))
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entity_document_id = get_document_id_for_entity(db_session, entity_id_name)
|
||||
if entity_document_id:
|
||||
return get_kg_doc_info_for_entity_name(
|
||||
db_session, entity_document_id, entity_type
|
||||
)
|
||||
else:
|
||||
return KGEntityDocInfo(
|
||||
doc_id=None,
|
||||
doc_semantic_id=None,
|
||||
doc_link=None,
|
||||
semantic_entity_name=entity_id_name,
|
||||
semantic_linked_entity_name=entity_id_name,
|
||||
)
|
||||
|
||||
|
||||
def rename_entities_in_answer(answer: str) -> str:
|
||||
"""
|
||||
Process entity references in the answer string by:
|
||||
1. Extracting all strings matching <str>:<str> or <str>: <str> patterns
|
||||
2. Looking up these references in the entity table
|
||||
3. Replacing valid references with their corresponding values
|
||||
|
||||
Args:
|
||||
answer: The input string containing potential entity references
|
||||
|
||||
Returns:
|
||||
str: The processed string with entity references replaced
|
||||
"""
|
||||
# Extract all entity references using regex
|
||||
# Pattern matches both <str>:<str> and <str>: <str> formats
|
||||
pattern = r"([^:\s]+):\s*([^\s,;.]+)"
|
||||
matches = re.finditer(pattern, answer)
|
||||
|
||||
# get active entity types
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_entity_types = [
|
||||
x.id_name for x in get_entity_types(db_session, active=True)
|
||||
]
|
||||
|
||||
# Collect extracted references
|
||||
entity_refs = [match.group(0).strip(":") for match in matches]
|
||||
|
||||
# Create dictionary for processed references
|
||||
processed_refs = {}
|
||||
|
||||
for entity_ref in entity_refs:
|
||||
entity_ref_split = entity_ref.split(":")
|
||||
if len(entity_ref_split) != 2:
|
||||
logger.warning(
|
||||
f"Invalid entity reference - number of colons is not 2 but {len(entity_ref_split)}"
|
||||
)
|
||||
continue
|
||||
entity_type, entity_name = entity_ref.split(":")
|
||||
entity_type = entity_type.upper().strip()
|
||||
if entity_type not in active_entity_types:
|
||||
continue
|
||||
entity_name = entity_name.capitalize().strip()
|
||||
potential_entity_id_name = f"{entity_type}:{entity_name}"
|
||||
|
||||
replacement_candidate = get_doc_information_for_entity(potential_entity_id_name)
|
||||
|
||||
if replacement_candidate.doc_id:
|
||||
processed_refs[entity_ref] = (
|
||||
replacement_candidate.semantic_linked_entity_name
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
# Replace all references in the answer
|
||||
for ref, replacement in processed_refs.items():
|
||||
answer = answer.replace(ref, replacement)
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
def build_document_context(
|
||||
document: InferenceSection | LlmDoc, document_number: int
|
||||
) -> str:
|
||||
"""
|
||||
Build a context string for a document.
|
||||
"""
|
||||
|
||||
metadata_list: list[str] = []
|
||||
document_content: str | None = None
|
||||
info_source: InferenceChunk | LlmDoc | None = None
|
||||
info_content: str | None = None
|
||||
|
||||
if isinstance(document, InferenceSection):
|
||||
info_source = document.center_chunk
|
||||
info_content = document.combined_content
|
||||
elif isinstance(document, LlmDoc):
|
||||
info_source = document
|
||||
info_content = document.content
|
||||
|
||||
for key, value in info_source.metadata.items():
|
||||
metadata_list.append(f" - {key}: {value}")
|
||||
|
||||
if metadata_list:
|
||||
metadata_str = "- Document Metadata:\n" + "\n".join(metadata_list)
|
||||
else:
|
||||
metadata_str = ""
|
||||
|
||||
# Construct document header with number and semantic identifier
|
||||
doc_header = f"Document {str(document_number)}: {info_source.semantic_identifier}"
|
||||
|
||||
# Combine all parts with proper spacing
|
||||
document_content = f"{doc_header}\n\n{metadata_str}\n\n{info_content}"
|
||||
|
||||
return document_content
|
||||
|
||||
|
||||
def get_near_empty_step_results(
|
||||
step_number: int,
|
||||
step_answer: str,
|
||||
verified_reranked_documents: list[InferenceSection] = [],
|
||||
) -> SubQuestionAnswerResults:
|
||||
"""
|
||||
Get near-empty step results from a list of step results.
|
||||
"""
|
||||
return SubQuestionAnswerResults(
|
||||
question=STEP_DESCRIPTIONS[step_number].description,
|
||||
question_id="0_" + str(step_number),
|
||||
answer=step_answer,
|
||||
verified_high_quality=True,
|
||||
sub_query_retrieval_results=[],
|
||||
verified_reranked_documents=verified_reranked_documents,
|
||||
context_documents=[],
|
||||
cited_documents=[],
|
||||
sub_question_retrieval_stats=AgentChunkRetrievalStats(
|
||||
verified_count=None,
|
||||
verified_avg_scores=None,
|
||||
rejected_count=None,
|
||||
rejected_avg_scores=None,
|
||||
verified_doc_chunk_ids=[],
|
||||
dismissed_doc_chunk_ids=[],
|
||||
),
|
||||
)
|
||||
49
backend/onyx/agents/agent_search/kb_search/models.py
Normal file
49
backend/onyx/agents/agent_search/kb_search/models.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import YesNoEnum
|
||||
|
||||
|
||||
class KGQuestionEntityExtractionResult(BaseModel):
|
||||
entities: list[str]
|
||||
terms: list[str]
|
||||
time_filter: str | None
|
||||
|
||||
|
||||
class KGAnswerApproach(BaseModel):
|
||||
search_type: KGSearchType
|
||||
search_strategy: KGAnswerStrategy
|
||||
format: KGAnswerFormat
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
|
||||
|
||||
class KGQuestionRelationshipExtractionResult(BaseModel):
|
||||
relationships: list[str]
|
||||
|
||||
|
||||
class KGQuestionExtractionResult(BaseModel):
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
time_filter: str | None
|
||||
|
||||
|
||||
class KGExpandedGraphObjects(BaseModel):
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
|
||||
|
||||
class KGSteps(BaseModel):
|
||||
description: str
|
||||
activities: list[str]
|
||||
|
||||
|
||||
class KGEntityDocInfo(BaseModel):
|
||||
doc_id: str | None
|
||||
doc_semantic_id: str | None
|
||||
doc_link: str | None
|
||||
semantic_entity_name: str
|
||||
semantic_linked_entity_name: str
|
||||
@@ -0,0 +1,261 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_structure
|
||||
from onyx.agents.agent_search.kb_search.models import KGQuestionEntityExtractionResult
|
||||
from onyx.agents.agent_search.kb_search.models import (
|
||||
KGQuestionRelationshipExtractionResult,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ERTExtractionUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.relationships import get_allowed_relationship_type_pairs
|
||||
from onyx.kg.extractions.extraction_processing import get_entity_types_str
|
||||
from onyx.kg.extractions.extraction_processing import get_relationship_types_str
|
||||
from onyx.prompts.kg_prompts import QUERY_ENTITY_EXTRACTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import QUERY_RELATIONSHIP_EXTRACTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_ert(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ERTExtractionUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
# recheck KG enablement at outset KG graph
|
||||
|
||||
if not config["metadata"]["config"].behavior.kg_config_settings.KG_ENABLED:
|
||||
logger.error("KG approach is not enabled, the KG agent flow cannot run.")
|
||||
raise ValueError("KG approach is not enabled, the KG agent flow cannot run.")
|
||||
|
||||
_KG_STEP_NR = 1
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
elif graph_config.tooling.search_tool.user is None:
|
||||
raise ValueError("User is not set")
|
||||
else:
|
||||
user_email = graph_config.tooling.search_tool.user.email
|
||||
user_name = user_email.split("@")[0] or "unknown"
|
||||
|
||||
# first four lines duplicates from generate_initial_answer
|
||||
question = graph_config.inputs.search_request.query
|
||||
today_date = datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# Stream structure of substeps out to the UI
|
||||
stream_write_step_structure(writer)
|
||||
|
||||
# Now specify core activities in the step (step 1)
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
### get the entities, terms, and filters
|
||||
|
||||
query_extraction_pre_prompt = QUERY_ENTITY_EXTRACTION_PROMPT.format(
|
||||
entity_types=all_entity_types,
|
||||
relationship_types=all_relationship_types,
|
||||
)
|
||||
|
||||
query_extraction_prompt = (
|
||||
query_extraction_pre_prompt.replace("---content---", question)
|
||||
.replace("---today_date---", today_date)
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=query_extraction_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = (
|
||||
KGQuestionEntityExtractionResult.model_validate_json(cleaned_response)
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = KGQuestionEntityExtractionResult(
|
||||
entities=[],
|
||||
terms=[],
|
||||
time_filter="",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
entity_extraction_result = KGQuestionEntityExtractionResult(
|
||||
entities=[],
|
||||
terms=[],
|
||||
time_filter="",
|
||||
)
|
||||
|
||||
# remove the attribute filters from the entities to for the purpose of the relationship
|
||||
entities_no_attributes = [
|
||||
entity.split("--")[0] for entity in entity_extraction_result.entities
|
||||
]
|
||||
entities_string_for_relationships = f"Entities: {entities_no_attributes}\n"
|
||||
ert_entities_string = f"Entities: {entities_string_for_relationships}\n"
|
||||
|
||||
### get the relationships
|
||||
|
||||
# find the relationship types that match the extracted entity types
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
allowed_relationship_pairs = get_allowed_relationship_type_pairs(
|
||||
db_session, entity_extraction_result.entities
|
||||
)
|
||||
|
||||
query_relationship_extraction_prompt = (
|
||||
QUERY_RELATIONSHIP_EXTRACTION_PROMPT.replace("---question---", question)
|
||||
.replace("---today_date---", today_date)
|
||||
.replace(
|
||||
"---relationship_type_options---",
|
||||
" - " + "\n - ".join(allowed_relationship_pairs),
|
||||
)
|
||||
.replace("---identified_entities---", ert_entities_string)
|
||||
.replace("---entity_types---", all_entity_types)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=query_relationship_extraction_prompt,
|
||||
)
|
||||
]
|
||||
fast_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace(" ", "")
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
|
||||
try:
|
||||
relationship_extraction_result = (
|
||||
KGQuestionRelationshipExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
relationship_extraction_result = KGQuestionRelationshipExtractionResult(
|
||||
relationships=[],
|
||||
)
|
||||
|
||||
## STEP 1
|
||||
# Stream answer pieces out to the UI for Step 1
|
||||
|
||||
extracted_entity_string = " \n ".join(
|
||||
[x.split("--")[0] for x in entity_extraction_result.entities]
|
||||
)
|
||||
extracted_relationship_string = " \n ".join(
|
||||
relationship_extraction_result.relationships
|
||||
)
|
||||
|
||||
step_answer = f"""Entities and relationships have been extracted from query - \n \
|
||||
Entities: {extracted_entity_string} - \n Relationships: {extracted_relationship_string}"""
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=1, answer=step_answer)
|
||||
|
||||
# Finish Step 1
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ERTExtractionUpdate(
|
||||
entities_types_str=all_entity_types,
|
||||
relationship_types_str=all_relationship_types,
|
||||
extracted_entities_w_attributes=entity_extraction_result.entities,
|
||||
extracted_entities_no_attributes=entities_no_attributes,
|
||||
extracted_relationships=relationship_extraction_result.relationships,
|
||||
extracted_terms=entity_extraction_result.terms,
|
||||
time_filter=entity_extraction_result.time_filter,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="extract entities terms",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
274
backend/onyx/agents/agent_search/kb_search/nodes/a2_analyze.py
Normal file
274
backend/onyx/agents/agent_search/kb_search/nodes/a2_analyze.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
create_minimal_connected_query_graph,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.models import KGAnswerApproach
|
||||
from onyx.agents.agent_search.kb_search.states import AnalysisUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerFormat
|
||||
from onyx.agents.agent_search.kb_search.states import KGAnswerStrategy
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import YesNoEnum
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_document_id_for_entity
|
||||
from onyx.kg.clustering.normalizations import normalize_entities
|
||||
from onyx.kg.clustering.normalizations import normalize_entities_w_attributes_from_map
|
||||
from onyx.kg.clustering.normalizations import normalize_relationships
|
||||
from onyx.kg.clustering.normalizations import normalize_terms
|
||||
from onyx.prompts.kg_prompts import STRATEGY_GENERATION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_fully_connected_entities(
|
||||
entities: list[str], relationships: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Analyze the connectedness of the entities and relationships.
|
||||
"""
|
||||
# Build a dictionary to track connections for each entity
|
||||
entity_connections: dict[str, set[str]] = {entity: set() for entity in entities}
|
||||
|
||||
# Parse relationships to build connection graph
|
||||
for relationship in relationships:
|
||||
# Split relationship into parts. Test for proper formatting just in case.
|
||||
# Should never be an error though at this point.
|
||||
parts = relationship.split("__")
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid relationship: {relationship}")
|
||||
|
||||
entity1 = parts[0]
|
||||
entity2 = parts[2]
|
||||
|
||||
# Add bidirectional connections
|
||||
if entity1 in entity_connections:
|
||||
entity_connections[entity1].add(entity2)
|
||||
if entity2 in entity_connections:
|
||||
entity_connections[entity2].add(entity1)
|
||||
|
||||
# Find entities connected to all others
|
||||
fully_connected_entities = []
|
||||
all_entities = set(entities)
|
||||
|
||||
for entity, connections in entity_connections.items():
|
||||
# Check if this entity is connected to all other entities
|
||||
if connections == all_entities - {entity}:
|
||||
fully_connected_entities.append(entity)
|
||||
|
||||
return fully_connected_entities
|
||||
|
||||
|
||||
def _check_for_single_doc(
|
||||
normalized_entities: list[str],
|
||||
raw_entities: list[str],
|
||||
normalized_relationships: list[str],
|
||||
raw_relationships: list[str],
|
||||
normalized_time_filter: str | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Check if the query is for a single document, like 'Summarize ticket ENG-2243K'.
|
||||
None is returned if the query is not for a single document.
|
||||
"""
|
||||
if (
|
||||
len(normalized_entities) == 1
|
||||
and len(raw_entities) == 1
|
||||
and len(normalized_relationships) == 0
|
||||
and len(raw_relationships) == 0
|
||||
and normalized_time_filter is None
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
single_doc_id = get_document_id_for_entity(
|
||||
db_session, normalized_entities[0]
|
||||
)
|
||||
else:
|
||||
single_doc_id = None
|
||||
return single_doc_id
|
||||
|
||||
|
||||
def analyze(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> AnalysisUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 2
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
entities = (
|
||||
state.extracted_entities_no_attributes
|
||||
) # attribute knowledge is not required for this step
|
||||
relationships = state.extracted_relationships
|
||||
terms = state.extracted_terms
|
||||
time_filter = state.time_filter
|
||||
|
||||
## STEP 2 - stream out goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
# Continue with node
|
||||
|
||||
normalized_entities = normalize_entities(entities)
|
||||
|
||||
query_graph_entities_w_attributes = normalize_entities_w_attributes_from_map(
|
||||
state.extracted_entities_w_attributes,
|
||||
normalized_entities.entity_normalization_map,
|
||||
)
|
||||
|
||||
normalized_relationships = normalize_relationships(
|
||||
relationships, normalized_entities.entity_normalization_map
|
||||
)
|
||||
normalized_terms = normalize_terms(terms)
|
||||
normalized_time_filter = time_filter
|
||||
|
||||
# If single-doc inquiry, send to single-doc processing directly
|
||||
|
||||
single_doc_id = _check_for_single_doc(
|
||||
normalized_entities=normalized_entities.entities,
|
||||
raw_entities=entities,
|
||||
normalized_relationships=normalized_relationships.relationships,
|
||||
raw_relationships=relationships,
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
)
|
||||
|
||||
# Expand the entities and relationships to make sure that entities are connected
|
||||
|
||||
graph_expansion = create_minimal_connected_query_graph(
|
||||
normalized_entities.entities,
|
||||
normalized_relationships.relationships,
|
||||
max_depth=2,
|
||||
)
|
||||
|
||||
query_graph_entities = graph_expansion.entities
|
||||
query_graph_relationships = graph_expansion.relationships
|
||||
|
||||
# Evaluate whether a search needs to be done after identifying all entities and relationships
|
||||
|
||||
strategy_generation_prompt = (
|
||||
STRATEGY_GENERATION_PROMPT.replace(
|
||||
"---entities---", "\n".join(query_graph_entities)
|
||||
)
|
||||
.replace("---relationships---", "\n".join(query_graph_relationships))
|
||||
.replace("---possible_entities---", state.entities_types_str)
|
||||
.replace("---possible_relationships---", state.relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=strategy_generation_prompt,
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
20,
|
||||
# fast_llm.invoke,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=5,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
approach_extraction_result = KGAnswerApproach.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
search_type = approach_extraction_result.search_type
|
||||
search_strategy = approach_extraction_result.search_strategy
|
||||
output_format = approach_extraction_result.format
|
||||
broken_down_question = approach_extraction_result.broken_down_question
|
||||
divide_and_conquer = approach_extraction_result.divide_and_conquer
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
search_type = KGSearchType.SEARCH
|
||||
search_strategy = KGAnswerStrategy.DEEP
|
||||
output_format = KGAnswerFormat.TEXT
|
||||
broken_down_question = None
|
||||
divide_and_conquer = YesNoEnum.NO
|
||||
if search_strategy is None or output_format is None:
|
||||
raise ValueError(f"Invalid strategy: {cleaned_response}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
raise e
|
||||
|
||||
# Stream out relevant results
|
||||
|
||||
if single_doc_id:
|
||||
search_strategy = (
|
||||
KGAnswerStrategy.DEEP
|
||||
) # if a single doc is identified, we will want to look at the details.
|
||||
|
||||
step_answer = f"Strategy and format have been extracted from query. Strategy: {search_strategy.value}, \
|
||||
Format: {output_format.value}, Broken down question: {broken_down_question}"
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
# End node
|
||||
|
||||
return AnalysisUpdate(
|
||||
normalized_core_entities=normalized_entities.entities,
|
||||
normalized_core_relationships=normalized_relationships.relationships,
|
||||
query_graph_entities_no_attributes=query_graph_entities,
|
||||
query_graph_entities_w_attributes=query_graph_entities_w_attributes,
|
||||
query_graph_relationships=query_graph_relationships,
|
||||
normalized_terms=normalized_terms.terms,
|
||||
normalized_time_filter=normalized_time_filter,
|
||||
strategy=search_strategy,
|
||||
broken_down_question=broken_down_question,
|
||||
output_format=output_format,
|
||||
divide_and_conquer=divide_and_conquer,
|
||||
single_doc_id=single_doc_id,
|
||||
search_type=search_type,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="analyze",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,394 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_step_activities
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import KGSearchType
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import SQLSimpleGenerationUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.engine import get_db_readonly_user_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.kg_temp_view import create_views
|
||||
from onyx.db.kg_temp_view import drop_views
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_CORRECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import SIMPLE_SQL_PROMPT
|
||||
from onyx.prompts.kg_prompts import SOURCE_DETECTION_PROMPT
|
||||
from onyx.prompts.kg_prompts import SQL_AGGREGATION_REMOVAL_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _sql_is_aggregate_query(sql_statement: str) -> bool:
|
||||
return any(
|
||||
agg_func in sql_statement.upper()
|
||||
for agg_func in ["COUNT(", "MAX(", "MIN(", "AVG(", "SUM("]
|
||||
)
|
||||
|
||||
|
||||
def _remove_aggregation(sql_statement: str, llm: LLM) -> str:
|
||||
"""
|
||||
Remove aggregate functions from the SQL statement.
|
||||
"""
|
||||
|
||||
sql_aggregation_removal_prompt = SQL_AGGREGATION_REMOVAL_PROMPT.replace(
|
||||
"---sql_statement---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=sql_aggregation_removal_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=800,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
sql_statement = cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
raise e
|
||||
|
||||
return sql_statement
|
||||
|
||||
|
||||
def _get_source_documents(sql_statement: str, llm: LLM) -> str | None:
|
||||
"""
|
||||
Remove aggregate functions from the SQL statement.
|
||||
"""
|
||||
|
||||
source_detection_prompt = SOURCE_DETECTION_PROMPT.replace(
|
||||
"---original_sql_statement---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=source_detection_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
cleaned_response: str | None = None
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1200,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
sql_statement = cleaned_response.split("<sql>")[1].strip()
|
||||
sql_statement = sql_statement.split("</sql>")[0].strip()
|
||||
|
||||
except Exception as e:
|
||||
if cleaned_response is not None:
|
||||
logger.error(
|
||||
f"Could not generate source documents SQL: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"Could not generate source documents SQL: {e}")
|
||||
return None
|
||||
|
||||
return sql_statement
|
||||
|
||||
|
||||
def generate_simple_sql(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> SQLSimpleGenerationUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 3
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
entities_types_str = state.entities_types_str
|
||||
relationship_types_str = state.relationship_types_str
|
||||
|
||||
single_doc_id = state.single_doc_id
|
||||
state.search_type
|
||||
|
||||
## STEP 3 - articulate goals
|
||||
|
||||
stream_write_step_activities(writer, _KG_STEP_NR)
|
||||
|
||||
if graph_config.tooling.search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
elif graph_config.tooling.search_tool.user is None:
|
||||
raise ValueError("User is not set")
|
||||
else:
|
||||
user_email = graph_config.tooling.search_tool.user.email
|
||||
user_name = user_email.split("@")[0]
|
||||
|
||||
if state.search_type == KGSearchType.SQL and single_doc_id:
|
||||
|
||||
# If single doc id already identified, we do not need to go through the KG
|
||||
# query cycle, saving a lot of time.
|
||||
|
||||
main_sql_statement: str | None = None
|
||||
query_results: list[dict[str, Any]] | None = None
|
||||
source_documents_sql: str | None = None
|
||||
source_document_results: list[str] | None = [single_doc_id]
|
||||
reasoning: str | None = (
|
||||
f"A KG query was not required as the source document was already identified: {single_doc_id}"
|
||||
)
|
||||
|
||||
step_answer = f"Source document already identified: {single_doc_id}"
|
||||
|
||||
elif state.search_type == KGSearchType.SEARCH:
|
||||
# If we do a filtered search, then we do not need to go through the SQL
|
||||
# generation process.
|
||||
|
||||
main_sql_statement = None
|
||||
query_results = None
|
||||
source_documents_sql = None
|
||||
source_document_results = None
|
||||
reasoning = "A KG query was not required as we will use a filtered search."
|
||||
|
||||
step_answer = "Filtered search will be used."
|
||||
|
||||
else:
|
||||
# If no single doc id already identified, we need to go through the KG
|
||||
# query cycle, including generating the SQL for the answer and the sources
|
||||
|
||||
# Build prompt
|
||||
|
||||
simple_sql_prompt = (
|
||||
SIMPLE_SQL_PROMPT.replace("---entity_types---", entities_types_str)
|
||||
.replace("---relationship_types---", relationship_types_str)
|
||||
.replace("---question---", question)
|
||||
.replace(
|
||||
"---query_entities_with_attributes---",
|
||||
"\n".join(state.query_graph_entities_w_attributes),
|
||||
)
|
||||
.replace(
|
||||
"---query_relationships---", "\n".join(state.query_graph_relationships)
|
||||
)
|
||||
.replace("---today_date---", datetime.now().strftime("%Y-%m-%d"))
|
||||
.replace("---user_name---", f"EMPLOYEE:{user_name}")
|
||||
)
|
||||
|
||||
logger.debug(f"simple_sql_prompt: {simple_sql_prompt}")
|
||||
|
||||
# Create temporary view
|
||||
|
||||
allowed_docs_view_name = f"allowed_docs_{user_email}".replace("@", "_").replace(
|
||||
".", "_"
|
||||
)
|
||||
kg_relationships_view_name = (
|
||||
f"kg_relationships_with_access_{user_email}".replace("@", "_").replace(
|
||||
".", "_"
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_views(
|
||||
db_session,
|
||||
user_email=user_email,
|
||||
allowed_docs_view_name=allowed_docs_view_name,
|
||||
kg_relationships_view_name=kg_relationships_view_name,
|
||||
)
|
||||
|
||||
# prepare SQL query generation
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=simple_sql_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
sql_statement = sql_statement.replace(
|
||||
"kg_relationship", kg_relationships_view_name
|
||||
)
|
||||
|
||||
reasoning = (
|
||||
cleaned_response.split("<reasoning>")[1]
|
||||
.strip()
|
||||
.split("</reasoning>")[0]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in strategy generation: {e}")
|
||||
raise e
|
||||
|
||||
# Correction if needed:
|
||||
|
||||
correction_prompt = SIMPLE_SQL_CORRECTION_PROMPT.replace(
|
||||
"---draft_sql---", sql_statement
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=correction_prompt,
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
primary_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=25,
|
||||
max_tokens=1500,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
|
||||
sql_statement = (
|
||||
cleaned_response.split("<sql>")[1].split("</sql>")[0].strip()
|
||||
)
|
||||
sql_statement = sql_statement.split(";")[0].strip() + ";"
|
||||
sql_statement = sql_statement.replace("sql", "").strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in generating the sql correction: {e}. Original model response: {cleaned_response}"
|
||||
)
|
||||
raise e
|
||||
|
||||
# Get SQL for source documents
|
||||
|
||||
source_documents_sql = _get_source_documents(sql_statement, llm=primary_llm)
|
||||
|
||||
logger.debug(f"source_documents_sql: {source_documents_sql}")
|
||||
|
||||
scalar_result = None
|
||||
query_results = None
|
||||
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(sql_statement))
|
||||
# Handle scalar results (like COUNT)
|
||||
if sql_statement.upper().startswith("SELECT COUNT"):
|
||||
scalar_result = result.scalar()
|
||||
query_results = (
|
||||
[{"count": int(scalar_result)}]
|
||||
if scalar_result is not None
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# Handle regular row results
|
||||
rows = result.fetchall()
|
||||
query_results = [dict(row._mapping) for row in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing SQL query: {e}")
|
||||
|
||||
raise e
|
||||
|
||||
source_document_results = None
|
||||
if source_documents_sql is not None and source_documents_sql != sql_statement:
|
||||
with get_db_readonly_user_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
result = db_session.execute(text(source_documents_sql))
|
||||
rows = result.fetchall()
|
||||
query_source_document_results = [dict(row._mapping) for row in rows]
|
||||
source_document_results = [
|
||||
source_document_result["source_document"]
|
||||
for source_document_result in query_source_document_results
|
||||
]
|
||||
except Exception as e:
|
||||
# No stopping here, the individualized SQL query is not mandatory
|
||||
logger.error(f"Error executing Individualized SQL query: {e}")
|
||||
# individualized_query_results = None
|
||||
|
||||
else:
|
||||
source_document_results = None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
drop_views(
|
||||
db_session,
|
||||
allowed_docs_view_name=allowed_docs_view_name,
|
||||
kg_relationships_view_name=kg_relationships_view_name,
|
||||
)
|
||||
|
||||
logger.info(f"query_results: {query_results}")
|
||||
logger.debug(f"sql_statement: {sql_statement}")
|
||||
|
||||
# Stream out reasoning and SQL query
|
||||
|
||||
step_answer = f"Reasoning: {reasoning} \n \n SQL Query: {sql_statement}"
|
||||
|
||||
main_sql_statement = sql_statement
|
||||
|
||||
if reasoning:
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return SQLSimpleGenerationUpdate(
|
||||
sql_query=main_sql_statement,
|
||||
sql_query_results=query_results,
|
||||
individualized_sql_query=None,
|
||||
individualized_query_results=None,
|
||||
source_documents_sql=source_documents_sql,
|
||||
source_document_results=source_document_results or [],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate simple sql",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=[],
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,178 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.states import DeepSearchFilterUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import KGFilterConstructionResults
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entity_type import get_entity_types_with_grounded_source_name
|
||||
from onyx.prompts.kg_prompts import SEARCH_FILTER_CONSTRUCTION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def construct_deep_search_filters(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter
|
||||
) -> DeepSearchFilterUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
entities_types_str = state.entities_types_str
|
||||
entities = state.query_graph_entities_no_attributes
|
||||
relationships = state.query_graph_relationships
|
||||
simple_sql_query = state.sql_query
|
||||
simple_sql_results = state.sql_query_results
|
||||
source_document_results = state.source_document_results
|
||||
if simple_sql_results:
|
||||
simple_sql_results_str = "\n".join([str(x) for x in simple_sql_results])
|
||||
else:
|
||||
simple_sql_results_str = "(no SQL results generated)"
|
||||
if source_document_results:
|
||||
source_document_results_str = "\n".join(
|
||||
[str(x) for x in source_document_results]
|
||||
)
|
||||
else:
|
||||
source_document_results_str = "(no source document results generated)"
|
||||
|
||||
search_filter_construction_prompt = (
|
||||
SEARCH_FILTER_CONSTRUCTION_PROMPT.replace(
|
||||
"---entity_type_descriptions---",
|
||||
entities_types_str,
|
||||
)
|
||||
.replace(
|
||||
"---entity_filters---",
|
||||
"\n".join(entities),
|
||||
)
|
||||
.replace(
|
||||
"---relationship_filters---",
|
||||
"\n".join(relationships),
|
||||
)
|
||||
.replace(
|
||||
"---sql_query---",
|
||||
simple_sql_query or "(no SQL generated)",
|
||||
)
|
||||
.replace(
|
||||
"---sql_results---",
|
||||
simple_sql_results_str or "(no SQL results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---source_document_results---",
|
||||
source_document_results_str or "(no source document results generated)",
|
||||
)
|
||||
.replace(
|
||||
"---question---",
|
||||
question,
|
||||
)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=search_filter_construction_prompt,
|
||||
)
|
||||
]
|
||||
llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
15,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=15,
|
||||
max_tokens=1400,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
|
||||
if last_bracket == -1 or first_bracket == -1:
|
||||
raise ValueError("No valid JSON found in LLM response - no brackets found")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
cleaned_response = cleaned_response.replace("{{", '{"')
|
||||
cleaned_response = cleaned_response.replace("}}", '"}')
|
||||
|
||||
try:
|
||||
|
||||
filter_results = KGFilterConstructionResults.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in extract_ert: {e}")
|
||||
filter_results = KGFilterConstructionResults(
|
||||
global_entity_filters=[],
|
||||
global_relationship_filters=[],
|
||||
local_entity_filters=[],
|
||||
source_document_filters=[],
|
||||
structure=[],
|
||||
)
|
||||
|
||||
div_con_structure = filter_results.structure
|
||||
|
||||
logger.info(f"div_con_structure: {div_con_structure}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
double_grounded_entity_types = get_entity_types_with_grounded_source_name(
|
||||
db_session
|
||||
)
|
||||
|
||||
source_division = False
|
||||
|
||||
if div_con_structure:
|
||||
for entity_type in double_grounded_entity_types:
|
||||
if entity_type.grounded_source_name.lower() in div_con_structure[0].lower():
|
||||
source_division = True
|
||||
break
|
||||
|
||||
return DeepSearchFilterUpdate(
|
||||
vespa_filter_results=filter_results,
|
||||
div_con_entities=div_con_structure,
|
||||
source_division=source_division,
|
||||
global_entity_filters=[
|
||||
f"{global_filter}:*"
|
||||
for global_filter in filter_results.global_entity_filters
|
||||
],
|
||||
global_relationship_filters=filter_results.global_relationship_filters,
|
||||
local_entity_filters=filter_results.local_entity_filters,
|
||||
source_filters=filter_results.source_document_filters,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="construct deep search filters",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
@@ -0,0 +1,180 @@
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
get_doc_information_for_entity,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectInput
|
||||
from onyx.agents.agent_search.kb_search.states import ResearchObjectUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.kg_prompts import KG_OBJECT_SOURCE_RESEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def process_individual_deep_search(
|
||||
state: ResearchObjectInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ResearchObjectUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = state.broken_down_question
|
||||
source_division = state.source_division
|
||||
source_filters = state.source_entity_filters
|
||||
|
||||
object = state.entity.replace(":", ": ").lower()
|
||||
|
||||
object_id = object.split(":")[1].strip()
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
|
||||
research_nr = state.research_nr
|
||||
|
||||
extended_question = f"{question} in regards to {object}"
|
||||
if source_division:
|
||||
extended_question = question
|
||||
|
||||
raw_kg_entity_filters = copy.deepcopy(
|
||||
list(set((state.vespa_filter_results.global_entity_filters + [state.entity])))
|
||||
)
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if ":" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += ":*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = copy.deepcopy(
|
||||
state.vespa_filter_results.global_relationship_filters
|
||||
)
|
||||
|
||||
# if this is a single-object analysis and object is a source,
|
||||
# drop the other filters as they already were included
|
||||
# in the KG query that led to this analysis
|
||||
|
||||
if source_division and source_filters:
|
||||
for source_filter in source_filters:
|
||||
if object_id.lower() == source_filter.lower():
|
||||
source_filters = [source_filter]
|
||||
kg_relationship_filters = []
|
||||
kg_entity_filters = []
|
||||
break
|
||||
|
||||
logger.info("Research for object: " + object)
|
||||
logger.info(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.info(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query=f"{get_doc_information_for_entity(object).semantic_entity_name}",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=research_nr + 1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = research(
|
||||
question=extended_question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=source_filters,
|
||||
search_tool=search_tool,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Built prompt
|
||||
|
||||
datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
kg_object_source_research_prompt = KG_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
question=extended_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
object_research_results = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
|
||||
return ResearchObjectUpdate(
|
||||
research_object_results=[
|
||||
{
|
||||
"object": object.replace(":", ": ").capitalize(),
|
||||
"results": object_research_results,
|
||||
}
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="process individual deep search",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[],
|
||||
)
|
||||
@@ -0,0 +1,176 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import write_custom_event
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.kg_prompts import KG_SEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def filtered_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to do a filtered search.
|
||||
"""
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError("search_tool is not provided")
|
||||
|
||||
if not state.vespa_filter_results:
|
||||
raise ValueError("vespa_filter_results is not provided")
|
||||
raw_kg_entity_filters = list(
|
||||
set((state.vespa_filter_results.global_entity_filters))
|
||||
)
|
||||
|
||||
kg_entity_filters = []
|
||||
for raw_kg_entity_filter in raw_kg_entity_filters:
|
||||
if ":" not in raw_kg_entity_filter:
|
||||
raw_kg_entity_filter += ":*"
|
||||
kg_entity_filters.append(raw_kg_entity_filter)
|
||||
|
||||
kg_relationship_filters = state.vespa_filter_results.global_relationship_filters
|
||||
|
||||
logger.info("Starting filtered search")
|
||||
logger.debug(f"kg_entity_filters: {kg_entity_filters}")
|
||||
logger.debug(f"kg_relationship_filters: {kg_relationship_filters}")
|
||||
|
||||
# Step 4 - stream out the research query
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Conduct a filtered search",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=kg_entity_filters,
|
||||
kg_relationships=kg_relationship_filters,
|
||||
kg_sources=None,
|
||||
search_tool=search_tool,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(
|
||||
answer_generation_documents.context_documents
|
||||
):
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Built prompt
|
||||
|
||||
datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
kg_object_source_research_prompt = KG_SEARCH_PROMPT.format(
|
||||
question=question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=kg_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
llm = primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
filtered_search_answer = str(llm_response.content).replace("```json\n", "")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
step_answer = "Filtered search is complete."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=filtered_search_answer,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate simple sql",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR,
|
||||
step_answer=step_answer,
|
||||
verified_reranked_documents=retrieved_docs,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import ConsolidatedResearchUpdate
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def consolidate_individual_deep_search(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ConsolidatedResearchUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.search_request.query
|
||||
state.entities_types_str
|
||||
|
||||
research_object_results = state.research_object_results
|
||||
|
||||
consolidated_research_object_results_str = "\n".join(
|
||||
[f"{x['object']}: {x['results']}" for x in research_object_results]
|
||||
)
|
||||
|
||||
consolidated_research_object_results_str = rename_entities_in_answer(
|
||||
consolidated_research_object_results_str
|
||||
)
|
||||
|
||||
step_answer = "All research is complete. Consolidating results..."
|
||||
|
||||
stream_write_step_answer_explicit(
|
||||
writer, answer=step_answer, level=0, step_nr=_KG_STEP_NR
|
||||
)
|
||||
|
||||
stream_close_step_answer(writer, level=0, step_nr=_KG_STEP_NR)
|
||||
|
||||
return ConsolidatedResearchUpdate(
|
||||
consolidated_research_object_results_str=consolidated_research_object_results_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="generate simple sql",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import get_near_empty_step_results
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_close_step_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_step_answer_explicit,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.kb_search.states import ResultsDataUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.db.document import get_base_llm_doc_information
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _general_format(result: dict[str, Any]) -> str:
|
||||
name = result.get("name")
|
||||
entity_type_id_name: Any = result.get("entity_type_id_name")
|
||||
result.get("id_name")
|
||||
|
||||
assert entity_type_id_name is str
|
||||
return f"{entity_type_id_name.capitalize()}: {name}"
|
||||
|
||||
|
||||
def _get_formated_source_reference_results(
|
||||
source_document_results: list[str] | None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Generate reference results from the query results data string.
|
||||
"""
|
||||
|
||||
if source_document_results is None:
|
||||
return None
|
||||
|
||||
# get all entities that correspond to an Onyx document
|
||||
document_ids = source_document_results
|
||||
|
||||
with get_session_with_current_tenant() as session:
|
||||
llm_doc_information_results = get_base_llm_doc_information(
|
||||
session, document_ids
|
||||
)
|
||||
|
||||
if len(llm_doc_information_results) == 0:
|
||||
return ""
|
||||
|
||||
return (
|
||||
f"\n \n Here are {len(llm_doc_information_results)} examples: \n \n "
|
||||
+ " \n \n ".join(llm_doc_information_results)
|
||||
)
|
||||
|
||||
|
||||
def process_kg_only_answers(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ResultsDataUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
_KG_STEP_NR = 4
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.search_request.query
|
||||
query_results = state.sql_query_results
|
||||
state.individualized_query_results
|
||||
source_document_results = state.source_document_results
|
||||
|
||||
# we use this stream write explicitly
|
||||
|
||||
write_custom_event(
|
||||
"subqueries",
|
||||
SubQueryPiece(
|
||||
sub_query="Formatted References",
|
||||
level=0,
|
||||
level_question_num=_KG_STEP_NR,
|
||||
query_id=1,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
query_results_list = []
|
||||
|
||||
if query_results:
|
||||
for query_result in query_results:
|
||||
query_results_list.append(str(query_result).replace(":", ": ").capitalize())
|
||||
else:
|
||||
raise ValueError("No query results were found")
|
||||
|
||||
query_results_data_str = "\n".join(query_results_list)
|
||||
|
||||
source_reference_result_str = _get_formated_source_reference_results(
|
||||
source_document_results
|
||||
)
|
||||
|
||||
## STEP 4 - same components as Step 1
|
||||
|
||||
step_answer = (
|
||||
"No further research is needed, the answer is derived from the knowledge graph."
|
||||
)
|
||||
|
||||
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=step_answer)
|
||||
|
||||
stream_close_step_answer(writer, _KG_STEP_NR)
|
||||
|
||||
return ResultsDataUpdate(
|
||||
query_results_data_str=query_results_data_str,
|
||||
individualized_query_results_data_str="",
|
||||
reference_results_str=source_reference_result_str,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="kg query results data processing",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
step_results=[
|
||||
get_near_empty_step_results(
|
||||
step_number=_KG_STEP_NR, step_answer=step_answer
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -0,0 +1,284 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import rename_entities_in_answer
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_close_main_answer,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import stream_write_close_steps
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import (
|
||||
stream_write_main_answer_token,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.ops import research
|
||||
from onyx.agents.agent_search.kb_search.states import MainOutput
|
||||
from onyx.agents.agent_search.kb_search.states import MainState
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.calculations import (
|
||||
get_answer_generation_documents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import log_agent_sub_question_results
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_EXAMPLES_PROMPT
|
||||
from onyx.prompts.kg_prompts import OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT
|
||||
from onyx.tools.tool_implementations.search.search_tool import IndexFilters
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchQueryInfo
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchType
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_answer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> MainOutput:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.search_request.query
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
if search_tool is None:
|
||||
raise ValueError("Search tool is not set")
|
||||
|
||||
# Close out previous streams of steps
|
||||
|
||||
# DECLARE STEPS DONE
|
||||
|
||||
stream_write_close_steps(writer)
|
||||
|
||||
## MAIN ANSWER
|
||||
|
||||
# identify whether documents have already been retrieved
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
for step_result in state.step_results:
|
||||
retrieved_docs += step_result.verified_reranked_documents
|
||||
|
||||
# if still needed, get a search done and send the results to the UI
|
||||
|
||||
if not retrieved_docs and state.source_document_results:
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
retrieved_docs = cast(
|
||||
list[InferenceSection],
|
||||
research(
|
||||
question=question,
|
||||
kg_entities=[],
|
||||
kg_relationships=[],
|
||||
kg_sources=state.source_document_results[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
],
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
kg_chunk_id_zero_only=True,
|
||||
inference_sections_only=True,
|
||||
),
|
||||
)
|
||||
|
||||
answer_generation_documents = get_answer_generation_documents(
|
||||
relevant_docs=retrieved_docs,
|
||||
context_documents=retrieved_docs,
|
||||
original_question_docs=retrieved_docs,
|
||||
max_docs=KG_RESEARCH_NUM_RETRIEVED_DOCS,
|
||||
)
|
||||
|
||||
relevance_list = relevance_from_docs(
|
||||
answer_generation_documents.streaming_documents
|
||||
)
|
||||
|
||||
assert graph_config.tooling.search_tool is not None
|
||||
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=SearchQueryInfo(
|
||||
predicted_search=SearchType.KEYWORD,
|
||||
# acl here is empty, because the searach alrady happened and
|
||||
# we are streaming out the results.
|
||||
final_filters=IndexFilters(access_control_list=[]),
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
):
|
||||
write_custom_event(
|
||||
"tool_response",
|
||||
ExtendedToolResponse(
|
||||
id=tool_response.id,
|
||||
response=tool_response.response,
|
||||
level=0,
|
||||
level_question_num=0, # 0, 0 is the base question
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
# continue with the answer generation
|
||||
|
||||
output_format = (
|
||||
state.output_format.value
|
||||
if state.output_format
|
||||
else "<you be the judge how to best present the data>"
|
||||
)
|
||||
|
||||
# if deep path was taken:
|
||||
|
||||
consolidated_research_object_results_str = (
|
||||
state.consolidated_research_object_results_str
|
||||
)
|
||||
reference_results_str = (
|
||||
state.reference_results_str
|
||||
) # will not be part of LLM. Manually added to the answer
|
||||
|
||||
# if simple path was taken:
|
||||
introductory_answer = state.query_results_data_str # from simple answer path only
|
||||
if consolidated_research_object_results_str:
|
||||
research_results = consolidated_research_object_results_str
|
||||
else:
|
||||
research_results = ""
|
||||
|
||||
if introductory_answer:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(introductory_answer),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace(
|
||||
"---introductory_answer---",
|
||||
rename_entities_in_answer(consolidated_research_object_results_str),
|
||||
)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
)
|
||||
elif research_results and not consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
elif consolidated_research_object_results_str:
|
||||
output_format_prompt = (
|
||||
OUTPUT_FORMAT_NO_EXAMPLES_PROMPT.replace("---question---", question)
|
||||
.replace("---output_format---", str(output_format) if output_format else "")
|
||||
.replace(
|
||||
"---research_results---", rename_entities_in_answer(research_results)
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("No research results or introductory answer provided")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=output_format_prompt,
|
||||
)
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
|
||||
dispatch_timings: list[float] = []
|
||||
response: list[str] = []
|
||||
|
||||
def stream_answer() -> list[str]:
|
||||
# Get the LLM's tokenizer
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=fast_llm.config.model_name,
|
||||
provider_type=fast_llm.config.model_provider,
|
||||
)
|
||||
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=1000,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
if not isinstance(content, str):
|
||||
raise ValueError(
|
||||
f"Expected content to be a string, but got {type(content)}"
|
||||
)
|
||||
|
||||
# Tokenize the content using the LLM's tokenizer
|
||||
tokens = llm_tokenizer.tokenize(content)
|
||||
for token in tokens:
|
||||
start_stream_token = datetime.now()
|
||||
stream_write_main_answer_token(
|
||||
writer, token, level=0, level_question_num=0
|
||||
)
|
||||
end_stream_token = datetime.now()
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(token)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
30,
|
||||
stream_answer,
|
||||
)
|
||||
|
||||
if reference_results_str:
|
||||
# Get the LLM's tokenizer
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=fast_llm.config.model_name,
|
||||
provider_type=fast_llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Tokenize and stream the reference results
|
||||
tokens = llm_tokenizer.tokenize(reference_results_str)
|
||||
for token in tokens:
|
||||
# Replace newlines with HTML line breaks
|
||||
# if token == ' \n':
|
||||
# token = ' <br>'
|
||||
stream_write_main_answer_token(writer, token)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not generate the answer. Error {e}")
|
||||
|
||||
stream_write_close_main_answer(writer)
|
||||
|
||||
# Persist the sub-answer in the database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
chat_session_id = graph_config.persistence.chat_session_id
|
||||
primary_message_id = graph_config.persistence.message_id
|
||||
sub_question_answer_results = state.step_results
|
||||
|
||||
log_agent_sub_question_results(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
primary_message_id=primary_message_id,
|
||||
sub_question_answer_results=sub_question_answer_results,
|
||||
)
|
||||
|
||||
return MainOutput(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="query completed",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
65
backend/onyx/agents/agent_search/kb_search/ops.py
Normal file
65
backend/onyx/agents/agent_search/kb_search/ops.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.kg_configs import KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def research(
|
||||
question: str,
|
||||
search_tool: SearchTool,
|
||||
document_sources: list[DocumentSource] | None = None,
|
||||
time_cutoff: datetime | None = None,
|
||||
kg_entities: list[str] | None = None,
|
||||
kg_relationships: list[str] | None = None,
|
||||
kg_terms: list[str] | None = None,
|
||||
kg_sources: list[str] | None = None,
|
||||
kg_chunk_id_zero_only: bool = False,
|
||||
inference_sections_only: bool = False,
|
||||
) -> list[LlmDoc] | list[InferenceSection]:
|
||||
# new db session to avoid concurrency issues
|
||||
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
retrieved_docs: list[LlmDoc] | list[InferenceSection] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
document_sources=document_sources,
|
||||
time_cutoff=time_cutoff,
|
||||
kg_entities=kg_entities,
|
||||
kg_relationships=kg_relationships,
|
||||
kg_terms=kg_terms,
|
||||
kg_sources=kg_sources,
|
||||
kg_chunk_id_zero_only=kg_chunk_id_zero_only,
|
||||
),
|
||||
):
|
||||
if (
|
||||
inference_sections_only
|
||||
and tool_response.id == "search_response_summary"
|
||||
):
|
||||
retrieved_docs = tool_response.response.top_sections[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
]
|
||||
retrieved_docs = cast(list[InferenceSection], retrieved_docs)
|
||||
break
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
elif tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
retrieved_docs = cast(list[LlmDoc], tool_response.response)[
|
||||
:KG_RESEARCH_NUM_RETRIEVED_DOCS
|
||||
]
|
||||
break
|
||||
return retrieved_docs
|
||||
163
backend/onyx/agents/agent_search/kb_search/states.py
Normal file
163
backend/onyx/agents/agent_search/kb_search/states.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from enum import Enum
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class StepResults(BaseModel):
|
||||
question: str
|
||||
question_id: str
|
||||
answer: str
|
||||
sub_query_retrieval_results: list[QueryRetrievalResult]
|
||||
verified_reranked_documents: list[InferenceSection]
|
||||
context_documents: list[InferenceSection]
|
||||
cited_documents: list[InferenceSection]
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
step_results: Annotated[list[SubQuestionAnswerResults], add]
|
||||
|
||||
|
||||
class KGFilterConstructionResults(BaseModel):
|
||||
global_entity_filters: list[str]
|
||||
global_relationship_filters: list[str]
|
||||
local_entity_filters: list[list[str]]
|
||||
source_document_filters: list[str]
|
||||
structure: list[str]
|
||||
|
||||
|
||||
class KGSearchType(Enum):
|
||||
SEARCH = "SEARCH"
|
||||
SQL = "SQL"
|
||||
|
||||
|
||||
class KGAnswerStrategy(Enum):
|
||||
DEEP = "DEEP"
|
||||
SIMPLE = "SIMPLE"
|
||||
|
||||
|
||||
class KGAnswerFormat(Enum):
|
||||
LIST = "LIST"
|
||||
TEXT = "TEXT"
|
||||
|
||||
|
||||
class YesNoEnum(str, Enum):
|
||||
YES = "yes"
|
||||
NO = "no"
|
||||
|
||||
|
||||
class AnalysisUpdate(LoggerUpdate):
|
||||
normalized_core_entities: list[str] = []
|
||||
normalized_core_relationships: list[str] = []
|
||||
query_graph_entities_no_attributes: list[str] = []
|
||||
query_graph_entities_w_attributes: list[str] = []
|
||||
query_graph_relationships: list[str] = []
|
||||
normalized_terms: list[str] = []
|
||||
normalized_time_filter: str | None = None
|
||||
strategy: KGAnswerStrategy | None = None
|
||||
output_format: KGAnswerFormat | None = None
|
||||
broken_down_question: str | None = None
|
||||
divide_and_conquer: YesNoEnum | None = None
|
||||
single_doc_id: str | None = None
|
||||
search_type: KGSearchType | None = None
|
||||
|
||||
|
||||
class SQLSimpleGenerationUpdate(LoggerUpdate):
|
||||
sql_query: str | None = None
|
||||
sql_query_results: list[Dict[Any, Any]] | None = None
|
||||
individualized_sql_query: str | None = None
|
||||
individualized_query_results: list[Dict[Any, Any]] | None = None
|
||||
source_documents_sql: str | None = None
|
||||
source_document_results: list[str] | None = None
|
||||
|
||||
|
||||
class ConsolidatedResearchUpdate(LoggerUpdate):
|
||||
consolidated_research_object_results_str: str | None = None
|
||||
|
||||
|
||||
class DeepSearchFilterUpdate(LoggerUpdate):
|
||||
vespa_filter_results: KGFilterConstructionResults | None = None
|
||||
div_con_entities: list[str] | None = None
|
||||
source_division: bool | None = None
|
||||
global_entity_filters: list[str] | None = None
|
||||
global_relationship_filters: list[str] | None = None
|
||||
local_entity_filters: list[list[str]] | None = None
|
||||
source_filters: list[str] | None = None
|
||||
|
||||
|
||||
class ResearchObjectOutput(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
class ERTExtractionUpdate(LoggerUpdate):
|
||||
entities_types_str: str = ""
|
||||
relationship_types_str: str = ""
|
||||
extracted_entities_w_attributes: list[str] = []
|
||||
extracted_entities_no_attributes: list[str] = []
|
||||
extracted_relationships: list[str] = []
|
||||
extracted_terms: list[str] = []
|
||||
time_filter: str | None = None
|
||||
|
||||
|
||||
class ResultsDataUpdate(LoggerUpdate):
|
||||
query_results_data_str: str | None = None
|
||||
individualized_query_results_data_str: str | None = None
|
||||
reference_results_str: str | None = None
|
||||
|
||||
|
||||
class ResearchObjectUpdate(LoggerUpdate):
|
||||
research_object_results: Annotated[list[dict[str, Any]], add] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
ERTExtractionUpdate,
|
||||
AnalysisUpdate,
|
||||
SQLSimpleGenerationUpdate,
|
||||
ResultsDataUpdate,
|
||||
ResearchObjectOutput,
|
||||
DeepSearchFilterUpdate,
|
||||
ResearchObjectUpdate,
|
||||
ConsolidatedResearchUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
|
||||
|
||||
class ResearchObjectInput(LoggerUpdate):
|
||||
research_nr: int
|
||||
entity: str
|
||||
broken_down_question: str
|
||||
vespa_filter_results: KGFilterConstructionResults
|
||||
source_division: bool | None
|
||||
source_entity_filters: list[str] | None
|
||||
@@ -0,0 +1,29 @@
|
||||
from onyx.agents.agent_search.kb_search.models import KGSteps
|
||||
|
||||
STEP_DESCRIPTIONS: dict[int, KGSteps] = {
|
||||
1: KGSteps(
|
||||
description="Analyzing the question...",
|
||||
activities=[
|
||||
"Entities in Query",
|
||||
"Relationships in Query",
|
||||
"Terms in Query",
|
||||
"Time Filters",
|
||||
],
|
||||
),
|
||||
2: KGSteps(
|
||||
description="Planning the response approach...",
|
||||
activities=["Query Execution Strategy", "Answer Format"],
|
||||
),
|
||||
3: KGSteps(
|
||||
description="Querying the Knowledge Graph..",
|
||||
activities=[
|
||||
"Knowledge Graph Query",
|
||||
"Knowledge Graph Query Results",
|
||||
"Query for Source Documents",
|
||||
"Source Documents",
|
||||
],
|
||||
),
|
||||
4: KGSteps(
|
||||
description="Conducting further research on source documents...", activities=[]
|
||||
),
|
||||
}
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.kg.models import KGConfigSettings
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
@@ -68,6 +69,7 @@ class GraphSearchConfig(BaseModel):
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
kg_config_settings: KGConfigSettings = KGConfigSettings()
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -18,6 +18,8 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
|
||||
from onyx.agents.agent_search.deep_search.main.states import (
|
||||
MainInput as MainInput,
|
||||
)
|
||||
from onyx.agents.agent_search.kb_search.graph_builder import kb_graph_builder
|
||||
from onyx.agents.agent_search.kb_search.states import MainInput as KBMainInput
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -85,7 +87,7 @@ def _parse_agent_event(
|
||||
def manage_sync_streaming(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
graph_input: BasicInput | MainInput | DCMainInput,
|
||||
graph_input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
) -> Iterable[StreamEvent]:
|
||||
message_id = config.persistence.message_id if config.persistence else None
|
||||
for event in compiled_graph.stream(
|
||||
@@ -99,7 +101,7 @@ def manage_sync_streaming(
|
||||
def run_graph(
|
||||
compiled_graph: CompiledStateGraph,
|
||||
config: GraphConfig,
|
||||
input: BasicInput | MainInput | DCMainInput,
|
||||
input: BasicInput | MainInput | DCMainInput | KBMainInput,
|
||||
) -> AnswerStream:
|
||||
config.behavior.perform_initial_search_decomposition = (
|
||||
INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
@@ -149,6 +151,21 @@ def run_basic_graph(
|
||||
return run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_kb_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
graph = kb_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = KBMainInput(log_messages=[])
|
||||
|
||||
yield ToolCallKickoff(
|
||||
tool_name="agent_search_0",
|
||||
tool_args={"query": config.inputs.search_request.query},
|
||||
)
|
||||
|
||||
yield from run_graph(compiled_graph, config, input)
|
||||
|
||||
|
||||
def run_dc_graph(
|
||||
config: GraphConfig,
|
||||
) -> AnswerStream:
|
||||
|
||||
@@ -159,3 +159,8 @@ BaseMessage_Content = str | list[str | dict[str, Any]]
|
||||
class QueryExpansionType(Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
class ReferenceResults(BaseModel):
|
||||
citations: list[str]
|
||||
general_entities: list[str]
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.relationships import delete_document_references_from_kg
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -119,6 +120,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_document_references_from_kg(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.agents.agent_search.run_graph import run_main_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AnswerPacket
|
||||
@@ -22,8 +23,10 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
@@ -127,8 +130,9 @@ class Answer:
|
||||
self.search_behavior_config = GraphSearchConfig(
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
allow_refinement=False,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
kg_config_settings=get_kg_config_settings(db_session),
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
@@ -143,10 +147,19 @@ class Answer:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
if self.graph_config.behavior.use_agentic_search:
|
||||
if self.graph_config.behavior.use_agentic_search and (
|
||||
self.graph_config.inputs.search_request.persona
|
||||
and self.graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
and self.graph_config.inputs.search_request.persona.name.startswith(
|
||||
"KG Dev"
|
||||
)
|
||||
):
|
||||
run_langgraph = run_kb_graph
|
||||
elif self.graph_config.behavior.use_agentic_search:
|
||||
run_langgraph = run_main_graph
|
||||
elif (
|
||||
self.graph_config.inputs.search_request.persona
|
||||
and USE_DIV_CON_AGENT
|
||||
and self.graph_config.inputs.search_request.persona.description.startswith(
|
||||
"DivCon Beta Agent"
|
||||
)
|
||||
|
||||
@@ -81,6 +81,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.chat import update_chat_session_updated_at_timestamp
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.milestone import update_user_assistant_milestone
|
||||
@@ -100,6 +101,16 @@ from onyx.file_store.utils import get_user_files
|
||||
from onyx.file_store.utils import load_all_chat_files
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.kg.clustering.incremental_cluster_updates import (
|
||||
kg_incremental_cluster_updates,
|
||||
)
|
||||
from onyx.kg.clustering.initial_clustering import kg_clustering
|
||||
from onyx.kg.configuration import populate_default_account_employee_definitions
|
||||
from onyx.kg.configuration import populate_default_grounded_entity_types
|
||||
from onyx.kg.extractions.extraction_processing import kg_extraction
|
||||
from onyx.kg.resets.reset_extractions import reset_extraction_kg_index
|
||||
from onyx.kg.resets.reset_index import reset_full_kg_index
|
||||
from onyx.kg.resets.reset_normalizations import reset_normalization_kg_index
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_main_llm_from_tuple
|
||||
@@ -664,6 +675,46 @@ def stream_chat_message_objects(
|
||||
|
||||
llm: LLM
|
||||
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if kg_config_settings.KG_ENABLED:
|
||||
|
||||
# Temporarily, until we have a draft UI for the KG Operations/Management
|
||||
|
||||
# get Vespa index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_str = search_settings.index_name
|
||||
|
||||
if new_msg_req.message == "kg_e":
|
||||
kg_extraction(tenant_id, index_str)
|
||||
|
||||
raise Exception("Extractions done")
|
||||
|
||||
elif new_msg_req.message == "kg_c":
|
||||
kg_clustering(tenant_id, index_str)
|
||||
raise Exception("Clustering done")
|
||||
|
||||
elif new_msg_req.message == "kg_i":
|
||||
kg_incremental_cluster_updates(tenant_id, index_str)
|
||||
raise Exception("Incremental clustering done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_full":
|
||||
reset_full_kg_index()
|
||||
raise Exception("Full KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_extraction":
|
||||
reset_extraction_kg_index()
|
||||
raise Exception("Extraction KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_rs_normalization":
|
||||
reset_normalization_kg_index()
|
||||
raise Exception("Normalization KG index reset done")
|
||||
|
||||
elif new_msg_req.message == "kg_setup":
|
||||
populate_default_grounded_entity_types()
|
||||
populate_default_account_employee_definitions()
|
||||
raise Exception("KG setup done")
|
||||
|
||||
try:
|
||||
# Move these variables inside the try block
|
||||
file_id_to_user_file = {}
|
||||
|
||||
@@ -184,6 +184,13 @@ POSTGRES_API_SERVER_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
|
||||
)
|
||||
|
||||
POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE") or 10
|
||||
)
|
||||
POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW") or 5
|
||||
)
|
||||
|
||||
# defaults to False
|
||||
# generally should only be used for
|
||||
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
|
||||
@@ -746,3 +753,9 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
DISABLE_AUTO_AUTH_REFRESH = (
|
||||
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Knowledge Graph Read Only User Configuration
|
||||
DB_READONLY_USER: str = os.environ.get("DB_READONLY_USER", "db_readonly_user")
|
||||
DB_READONLY_PASSWORD: str = urllib.parse.quote_plus(
|
||||
os.environ.get("DB_READONLY_PASSWORD") or "password"
|
||||
)
|
||||
|
||||
@@ -102,3 +102,5 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
@@ -468,3 +468,8 @@ if platform.system() == "Darwin":
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
|
||||
else:
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
|
||||
|
||||
|
||||
class OnyxCallTypes(str, Enum):
|
||||
FIREFLIES = "fireflies"
|
||||
GONG = "gong"
|
||||
|
||||
10
backend/onyx/configs/kg_configs.py
Normal file
10
backend/onyx/configs/kg_configs.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import os
|
||||
|
||||
KG_RESEARCH_NUM_RETRIEVED_DOCS: int = int(
|
||||
os.environ.get("KG_RESEARCH_NUM_RETRIEVED_DOCS", "25")
|
||||
)
|
||||
|
||||
|
||||
KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES: int = int(
|
||||
os.environ.get("KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES", "10")
|
||||
)
|
||||
@@ -193,12 +193,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
team {
|
||||
name
|
||||
}
|
||||
assignee {
|
||||
email
|
||||
}
|
||||
previousIdentifiers
|
||||
subIssueSortOrder
|
||||
priorityLabel
|
||||
identifier
|
||||
url
|
||||
branchName
|
||||
state {
|
||||
id
|
||||
name
|
||||
}
|
||||
customerTicketCount
|
||||
description
|
||||
comments {
|
||||
@@ -267,7 +274,19 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
title=node["title"],
|
||||
doc_updated_at=time_str_to_utc(node["updatedAt"]),
|
||||
metadata={
|
||||
"team": node["team"]["name"],
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
"team": (node.get("team") or {}).get("name"),
|
||||
"assignee": (node.get("assignee") or {}).get("email"),
|
||||
"state": (node.get("state") or {}).get("name"),
|
||||
"priority": node.get("priority"),
|
||||
"estimate": node.get("estimate"),
|
||||
"started_at": node.get("startedAt"),
|
||||
"completed_at": node.get("completedAt"),
|
||||
"created_at": node.get("createdAt"),
|
||||
"due_date": node.get("dueDate"),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -35,6 +35,28 @@ logger = setup_logger()
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
"Opportunity": {
|
||||
"Account": "account",
|
||||
"FiscalQuarter": "fiscal_quarter",
|
||||
"FiscalYear": "fiscal_year",
|
||||
"IsClosed": "is_closed",
|
||||
"Name": "name",
|
||||
"StageName": "stage_name",
|
||||
"Type": "type",
|
||||
"Amount": "amount",
|
||||
"CloseDate": "close_date",
|
||||
"Probability": "probability",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
},
|
||||
"Contact": {
|
||||
"Account": "account",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""Approach outline
|
||||
@@ -168,6 +190,8 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Always want to make sure user is grabbed for permissioning purposes
|
||||
all_types.add("User")
|
||||
# Always want to make sure account is grabbed for reference purposes
|
||||
all_types.add("Account")
|
||||
|
||||
logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -279,7 +303,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
f"remaining={len(updated_ids) - docs_processed}"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
parent_object = sf_db.get_record(
|
||||
parent_id, parent_type, isChild=False
|
||||
)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
@@ -291,6 +317,18 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
|
||||
@@ -172,7 +172,7 @@ def convert_sf_object_to_doc(
|
||||
|
||||
sections = [_extract_section(sf_object, base_url)]
|
||||
for id in sf_db.get_child_ids(sf_object.id):
|
||||
if not (child_object := sf_db.get_record(id)):
|
||||
if not (child_object := sf_db.get_record(id, isChild=True)):
|
||||
continue
|
||||
sections.append(_extract_section(child_object, base_url))
|
||||
|
||||
|
||||
@@ -456,7 +456,7 @@ class OnyxSalesforceSQLite:
|
||||
return result[0]
|
||||
|
||||
def get_record(
|
||||
self, object_id: str, object_type: str | None = None
|
||||
self, object_id: str, object_type: str | None = None, isChild: bool = False
|
||||
) -> SalesforceObject | None:
|
||||
"""Retrieve the record and return it as a SalesforceObject."""
|
||||
if self._conn is None:
|
||||
@@ -469,15 +469,44 @@ class OnyxSalesforceSQLite:
|
||||
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
# Get the object data and account data
|
||||
if object_type == "Account" or isChild:
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
"SELECT pso.data, r.parent_id as parent_id, sso.object_type FROM salesforce_objects pso \
|
||||
LEFT JOIN relationships r on r.child_id = pso.id \
|
||||
LEFT JOIN salesforce_objects sso on r.parent_id = sso.id \
|
||||
WHERE pso.id = ? ",
|
||||
(object_id,),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if not result:
|
||||
logger.warning(f"Object ID {object_id} not found")
|
||||
return None
|
||||
|
||||
data = json.loads(result[0])
|
||||
data = json.loads(result[0][0])
|
||||
|
||||
if object_type != "Account":
|
||||
|
||||
# convert any account ids of the relationships back into data fields, with name
|
||||
for row in result:
|
||||
|
||||
# the following skips Account objects.
|
||||
if len(row) < 3:
|
||||
continue
|
||||
|
||||
if row[1] and row[2] and row[2] == "Account":
|
||||
data["AccountId"] = row[1]
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?",
|
||||
(row[1],),
|
||||
)
|
||||
account_data = json.loads(cursor.fetchone()[0])
|
||||
data["Account"] = account_data.get("Name", "")
|
||||
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
def find_ids_by_type(self, object_type: str) -> list[str]:
|
||||
|
||||
@@ -113,6 +113,11 @@ class BaseFilters(BaseModel):
|
||||
tags: list[Tag] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
kg_sources: list[str] | None = None
|
||||
kg_chunk_id_zero_only: bool | None = False
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
|
||||
@@ -182,6 +182,11 @@ def retrieval_preprocessing(
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
kg_entities=preset_filters.kg_entities,
|
||||
kg_relationships=preset_filters.kg_relationships,
|
||||
kg_terms=preset_filters.kg_terms,
|
||||
kg_sources=preset_filters.kg_sources,
|
||||
kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.db.enums import IndexingMode
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.kg.models import KGConnectorData
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.models import StatusResponse
|
||||
@@ -334,3 +335,29 @@ def mark_ccpair_with_indexing_trigger(
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_kg_enabled_connectors(db_session: Session) -> list[KGConnectorData]:
|
||||
"""
|
||||
Retrieves a list of connector IDs that have not been KG processed for a given tenant.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[int]: List of connector IDs that have enabled KG extraction but have unprocessed documents
|
||||
"""
|
||||
try:
|
||||
stmt = select(Connector.id, Connector.source).where(
|
||||
Connector.kg_processing_enabled
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
connector_results = [
|
||||
KGConnectorData(id=row[0], source=row[1].lower())
|
||||
for row in result.fetchall()
|
||||
]
|
||||
|
||||
return connector_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching unprocessed connector IDs: {str(e)}")
|
||||
raise e
|
||||
|
||||
@@ -21,23 +21,33 @@ from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.agents.agent_search.kb_search.models import KGEntityDocInfo
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.entities import delete_from_kg_entities__no_commit
|
||||
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import User
|
||||
from onyx.db.relationships import delete_from_kg_relationships__no_commit
|
||||
from onyx.db.relationships import (
|
||||
delete_from_kg_relationships_extraction_staging__no_commit,
|
||||
)
|
||||
from onyx.db.tag import delete_document_tags_for_documents__no_commit
|
||||
from onyx.db.utils import model_to_dict
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -392,6 +402,7 @@ def upsert_documents(
|
||||
last_modified=datetime.now(timezone.utc),
|
||||
primary_owners=doc.primary_owners,
|
||||
secondary_owners=doc.secondary_owners,
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -605,7 +616,29 @@ def delete_documents_complete__no_commit(
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
|
||||
# Start by deleting the chunk stats for the documents
|
||||
# Start with the kg references
|
||||
|
||||
delete_from_kg_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_entities__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_relationships_extraction_staging__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_from_kg_entities_extraction_staging__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Continue with deleting the chunk stats for the documents
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
@@ -858,3 +891,262 @@ def fetch_chunk_count_for_document(
|
||||
) -> int | None:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_unprocessed_kg_document_batch_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
batch_size: int = 100,
|
||||
) -> list[DbDocument]:
|
||||
"""
|
||||
Retrieves a batch of documents that have not been processed for knowledge graph extraction.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
connector_id (int): The ID of the connector to get documents for
|
||||
batch_size (int): The maximum number of documents to retrieve
|
||||
Returns:
|
||||
list[DbDocument]: List of documents that need KG processing
|
||||
"""
|
||||
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
or_(
|
||||
DbDocument.kg_stage.is_(None),
|
||||
DbDocument.kg_stage == KGStage.NOT_STARTED,
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.order_by(DbDocument.doc_updated_at.desc())
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
documents = db_session.scalars(stmt).all()
|
||||
db_session.flush()
|
||||
|
||||
return list(documents)
|
||||
|
||||
|
||||
def get_kg_extracted_document_ids(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_stage is EXTRACTED.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been KG processed
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.EXTRACTED)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def update_document_kg_info(
|
||||
db_session: Session, document_id: str, kg_stage: KGStage
|
||||
) -> None:
|
||||
"""Updates the knowledge graph related information for a document.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to update
|
||||
kg_stage (KGStage): The stage of the knowledge graph processing for the document
|
||||
Raises:
|
||||
ValueError: If the document with the given ID is not found
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(
|
||||
kg_stage=kg_stage,
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def update_document_kg_stage(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
kg_stage: KGStage,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(DbDocument).where(DbDocument.id == document_id).values(kg_stage=kg_stage)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_document_kg_info(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> KGStage | None:
|
||||
"""Retrieves the knowledge graph processing status and data for a document.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
document_id (str): The ID of the document to query
|
||||
Returns:
|
||||
Optional[Tuple[bool, dict]]: A tuple containing:
|
||||
- bool: Whether the document has been KG processed
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Returns None if the document is not found
|
||||
"""
|
||||
stmt = select(DbDocument.kg_stage).where(DbDocument.id == document_id)
|
||||
result = db_session.execute(stmt).one_or_none()
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.kg_stage or KGStage.NOT_STARTED
|
||||
|
||||
|
||||
def get_all_kg_extracted_documents_info(
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
"""Retrieves the knowledge graph data for all documents that have been processed.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
List[Tuple[str, dict]]: A list of tuples containing:
|
||||
- str: The document ID
|
||||
- dict: The KG data containing 'entities', 'relationships', and 'terms'
|
||||
Only returns documents where kg_stage is EXTRACTED
|
||||
"""
|
||||
stmt = (
|
||||
select(DbDocument.id)
|
||||
.where(DbDocument.kg_stage == KGStage.EXTRACTED)
|
||||
.order_by(DbDocument.id)
|
||||
)
|
||||
|
||||
results = db_session.execute(stmt).all()
|
||||
return [str(doc_id) for doc_id in results]
|
||||
|
||||
|
||||
def get_base_llm_doc_information(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> list[str]:
|
||||
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
|
||||
results = db_session.execute(stmt).all()
|
||||
|
||||
documents = []
|
||||
|
||||
for doc_nr, doc in enumerate(results):
|
||||
bare_doc = doc[0]
|
||||
documents.append(
|
||||
f"""* [{bare_doc.semantic_id}]({bare_doc.link}) ({bare_doc.doc_updated_at})"""
|
||||
)
|
||||
|
||||
return documents[:KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES]
|
||||
|
||||
|
||||
def get_document_updated_at(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> datetime | None:
|
||||
"""Retrieves the doc_updated_at timestamp for a given document ID.
|
||||
Args:
|
||||
document_id (str): The ID of the document to query
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
Optional[datetime]: The doc_updated_at timestamp if found, None if document doesn't exist
|
||||
"""
|
||||
if len(document_id.split(":")) == 2:
|
||||
document_id = document_id.split(":")[1]
|
||||
elif len(document_id.split(":")) > 2:
|
||||
raise ValueError(f"Invalid document ID: {document_id}")
|
||||
else:
|
||||
pass
|
||||
|
||||
stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def reset_all_document_kg_stages(db_session: Session) -> int:
|
||||
"""Reset the KG stage of all documents that are not in NOT_STARTED state to NOT_STARTED.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
int: Number of documents that were reset
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.kg_stage != KGStage.NOT_STARTED)
|
||||
.values(kg_stage=KGStage.NOT_STARTED)
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def update_document_kg_stages(
|
||||
db_session: Session, source_stage: KGStage, target_stage: KGStage
|
||||
) -> int:
|
||||
"""Reset the KG stage only of documents back to NOT_STARTED.
|
||||
Part of reset flow for documemnts that have been extracted but not clustered.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
|
||||
Returns:
|
||||
int: Number of documents that were reset
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.kg_stage == source_stage)
|
||||
.values(kg_stage=target_stage)
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
|
||||
|
||||
def get_skipped_kg_documents(db_session: Session) -> list[str]:
|
||||
"""
|
||||
Retrieves all document IDs where kg_stage is SKIPPED.
|
||||
Args:
|
||||
db_session (Session): The database session to use
|
||||
Returns:
|
||||
list[str]: List of document IDs that have been skipped in KG processing
|
||||
"""
|
||||
stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.SKIPPED)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_kg_doc_info_for_entity_name(
|
||||
db_session: Session, document_id: str, entity_type: str
|
||||
) -> KGEntityDocInfo:
|
||||
"""
|
||||
Get the semantic ID and the link for an entity name.
|
||||
"""
|
||||
|
||||
result = (
|
||||
db_session.query(Document.semantic_id, Document.link)
|
||||
.filter(Document.id == document_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return KGEntityDocInfo(
|
||||
doc_id=None,
|
||||
doc_semantic_id=None,
|
||||
doc_link=None,
|
||||
semantic_entity_name=f"{entity_type}:{document_id}",
|
||||
semantic_linked_entity_name=f"{entity_type}:{document_id}",
|
||||
)
|
||||
|
||||
return KGEntityDocInfo(
|
||||
doc_id=document_id,
|
||||
doc_semantic_id=result[0],
|
||||
doc_link=result[1],
|
||||
semantic_entity_name=f"{entity_type.upper()}:{result.semantic_id}",
|
||||
semantic_linked_entity_name=f"[{entity_type.upper()}:{result.semantic_id}]({result[1]})",
|
||||
)
|
||||
|
||||
@@ -27,6 +27,8 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -187,7 +189,9 @@ def is_valid_schema_name(name: str) -> bool:
|
||||
|
||||
class SqlEngine:
|
||||
_engine: Engine | None = None
|
||||
_readonly_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_readonly_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
@classmethod
|
||||
@@ -252,12 +256,80 @@ class SqlEngine:
|
||||
|
||||
cls._engine = engine
|
||||
|
||||
@classmethod
|
||||
def init_readonly_engine(
|
||||
cls,
|
||||
pool_size: int,
|
||||
# is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
|
||||
max_overflow: int,
|
||||
**extra_engine_kwargs: Any,
|
||||
) -> None:
|
||||
"""NOTE: enforce that pool_size and pool_max_overflow are passed in. These are
|
||||
important args, and if incorrectly specified, we have run into hitting the pool
|
||||
limit / using too many connections and overwhelming the database."""
|
||||
with cls._readonly_lock:
|
||||
if cls._readonly_engine:
|
||||
return
|
||||
|
||||
if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
|
||||
raise ValueError(
|
||||
"Custom database user credentials not configured in environment variables"
|
||||
)
|
||||
|
||||
# Build connection string with custom user
|
||||
connection_string = build_connection_string(
|
||||
user=DB_READONLY_USER,
|
||||
password=DB_READONLY_PASSWORD,
|
||||
use_iam_auth=False, # Custom users typically don't use IAM auth
|
||||
db_api=SYNC_DB_API, # Explicitly use sync DB API
|
||||
)
|
||||
|
||||
# Start with base kwargs that are valid for all pool types
|
||||
final_engine_kwargs: dict[str, Any] = {}
|
||||
|
||||
if POSTGRES_USE_NULL_POOL:
|
||||
# if null pool is specified, then we need to make sure that
|
||||
# we remove any passed in kwargs related to pool size that would
|
||||
# cause the initialization to fail
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
final_engine_kwargs["poolclass"] = pool.NullPool
|
||||
if "pool_size" in final_engine_kwargs:
|
||||
del final_engine_kwargs["pool_size"]
|
||||
if "max_overflow" in final_engine_kwargs:
|
||||
del final_engine_kwargs["max_overflow"]
|
||||
else:
|
||||
final_engine_kwargs["pool_size"] = pool_size
|
||||
final_engine_kwargs["max_overflow"] = max_overflow
|
||||
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
|
||||
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
|
||||
|
||||
# any passed in kwargs override the defaults
|
||||
final_engine_kwargs.update(extra_engine_kwargs)
|
||||
|
||||
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
|
||||
# echo=True here for inspecting all emitted db queries
|
||||
engine = create_engine(connection_string, **final_engine_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
cls._readonly_engine = engine
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
if not cls._engine:
|
||||
raise RuntimeError("Engine not initialized. Must call init_engine first.")
|
||||
return cls._engine
|
||||
|
||||
@classmethod
|
||||
def get_readonly_engine(cls) -> Engine:
|
||||
if not cls._readonly_engine:
|
||||
raise RuntimeError(
|
||||
"Readonly engine not initialized. Must call init_readonly_engine first."
|
||||
)
|
||||
return cls._readonly_engine
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
cls._app_name = app_name
|
||||
@@ -307,6 +379,10 @@ def get_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
def get_readonly_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_readonly_engine()
|
||||
|
||||
|
||||
async def get_async_connection() -> Any:
|
||||
"""
|
||||
Custom connection function for async engine when using IAM auth.
|
||||
@@ -560,3 +636,46 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
|
||||
region = os.getenv("AWS_REGION_NAME", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_readonly_user_session_with_current_tenant() -> (
|
||||
Generator[Session, None, None]
|
||||
):
|
||||
"""
|
||||
Generate a database session using a custom database user for the current tenant.
|
||||
The custom user credentials are obtained from environment variables.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
readonly_engine = get_readonly_sqlalchemy_engine()
|
||||
|
||||
event.listen(readonly_engine, "checkout", _set_search_path_on_checkout__listener)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
with readonly_engine.connect() as connection:
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
329
backend/onyx/db/entities.py
Normal file
329
backend/onyx/db/entities.py
Normal file
@@ -0,0 +1,329 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.kg.models import KGGroundingType
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
|
||||
def add_entity(
|
||||
db_session: Session,
|
||||
kg_stage: KGStage,
|
||||
entity_type: str,
|
||||
name: str,
|
||||
document_id: str | None = None,
|
||||
occurrences: int = 0,
|
||||
event_time: datetime | None = None,
|
||||
attributes: dict[str, str] | None = None,
|
||||
) -> "KGEntity | KGEntityExtractionStaging | None":
|
||||
"""Add a new entity to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
kg_stage: KGStage of the entity
|
||||
entity_type: Type of the entity (must match an existing KGEntityType)
|
||||
name: Name of the entity
|
||||
occurrences: Number of clusters this entity has been found
|
||||
|
||||
Returns:
|
||||
KGEntity: The created entity
|
||||
"""
|
||||
entity_type = entity_type.upper()
|
||||
name = name.title()
|
||||
id_name = f"{entity_type}:{name}"
|
||||
|
||||
_KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
_KGEntityObject = KGEntityExtractionStaging
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
_KGEntityObject = KGEntity
|
||||
else:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage}")
|
||||
|
||||
# Create new entity
|
||||
stmt = (
|
||||
pg_insert(_KGEntityObject)
|
||||
.values(
|
||||
id_name=id_name,
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
occurrences=occurrences,
|
||||
event_time=event_time,
|
||||
attributes=attributes,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_=dict(
|
||||
# Direct numeric addition without text()
|
||||
occurrences=_KGEntityObject.occurrences + occurrences,
|
||||
# Keep other fields updated as before
|
||||
entity_type_id_name=entity_type,
|
||||
document_id=document_id,
|
||||
name=name,
|
||||
event_time=event_time,
|
||||
attributes=attributes,
|
||||
),
|
||||
)
|
||||
.returning(_KGEntityObject)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalar()
|
||||
|
||||
# Update the document's kg_stage if document_id is provided
|
||||
if document_id is not None:
|
||||
|
||||
db_session.query(Document).filter(Document.id == document_id).update(
|
||||
{"kg_stage": kg_stage, "kg_processing_time": datetime.now(timezone.utc)}
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None:
|
||||
"""
|
||||
Check if a document_id exists in the kg_entities table and return its id_name if found.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
document_id: The document ID to search for
|
||||
|
||||
Returns:
|
||||
The id_name of the matching KGEntity if found, None otherwise
|
||||
"""
|
||||
query = select(KGEntity).where(KGEntity.document_id == document_id)
|
||||
result = db.execute(query).scalar()
|
||||
return result
|
||||
|
||||
|
||||
def get_entities_by_grounding(
|
||||
db_session: Session, kg_stage: KGStage, grounding: KGGroundingType
|
||||
) -> List[KGEntity] | List[KGEntityExtractionStaging]:
|
||||
"""Get all entities by grounding type.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects for a given grounding type
|
||||
"""
|
||||
|
||||
_KGEntityObject: Type[KGEntity | KGEntityExtractionStaging]
|
||||
|
||||
if kg_stage not in [KGStage.EXTRACTED, KGStage.NORMALIZED]:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage}")
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
_KGEntityObject = KGEntityExtractionStaging
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
_KGEntityObject = KGEntity
|
||||
|
||||
result = list(
|
||||
db_session.query(_KGEntityObject)
|
||||
.join(
|
||||
KGEntityType,
|
||||
_KGEntityObject.entity_type_id_name == KGEntityType.id_name,
|
||||
)
|
||||
.filter(KGEntityType.grounding == grounding)
|
||||
.all()
|
||||
)
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
return cast(List[KGEntityExtractionStaging], result)
|
||||
else:
|
||||
return cast(List[KGEntity], result)
|
||||
|
||||
|
||||
def get_grounded_entities_by_types(
|
||||
db_session: Session, entity_types: List[str], grounding: KGGroundingType
|
||||
) -> List[KGEntity]:
|
||||
"""Get all entities matching an entity_type.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity types to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntity objects belonging to the specified entity types
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntity)
|
||||
.join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name)
|
||||
.filter(KGEntity.entity_type_id_name.in_(entity_types))
|
||||
.filter(KGEntityType.grounding == grounding)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def delete_entities_by_id_names(
|
||||
db_session: Session, id_names: list[str], kg_stage: KGStage
|
||||
) -> int:
|
||||
"""
|
||||
Delete entities from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of entity id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of entities deleted
|
||||
"""
|
||||
|
||||
if kg_stage not in [KGStage.EXTRACTED, KGStage.NORMALIZED]:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage}")
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
_KGEntityObject: Type[KGEntity | KGEntityExtractionStaging] = (
|
||||
KGEntityExtractionStaging
|
||||
)
|
||||
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
_KGEntityObject = KGEntity
|
||||
else:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage}")
|
||||
|
||||
deleted_count = (
|
||||
db_session.query(_KGEntityObject)
|
||||
.filter(_KGEntityObject.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_entity_names_for_types(
|
||||
db_session: Session, entity_types: List[str]
|
||||
) -> List[tuple[str, str | None]]:
|
||||
"""Get all entities that belong to the specified entity types.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types: List of entity type id_names to filter by
|
||||
|
||||
Returns:
|
||||
List of entity id_names belonging to the specified entity types
|
||||
"""
|
||||
entity_query = db_session.query(KGEntity).filter(
|
||||
KGEntity.entity_type_id_name.in_(entity_types)
|
||||
)
|
||||
|
||||
# Get document IDs from the filtered entities
|
||||
doc_ids = [e.document_id for e in entity_query.all() if e.document_id is not None]
|
||||
|
||||
# Get document info for those IDs
|
||||
doc_info: dict[str, tuple[str | None, str | None]] = {
|
||||
row[0].capitalize(): (row[1], row[2])
|
||||
for row in db_session.query(Document.id, Document.semantic_id, Document.link)
|
||||
.filter(Document.id.in_(doc_ids))
|
||||
.all()
|
||||
}
|
||||
|
||||
# Return entities with their document info
|
||||
|
||||
names: list[tuple[str, str | None]] = []
|
||||
for entity in entity_query.all():
|
||||
|
||||
if entity.document_id is None:
|
||||
names.append((entity.id_name, entity.id_name))
|
||||
continue
|
||||
|
||||
# Extract entity type from the full type ID
|
||||
entity_type = entity.entity_type_id_name.split(":")[0].upper()
|
||||
|
||||
# Get document info, defaulting to None if not found
|
||||
doc_semantic_id = doc_info.get(entity.document_id.capitalize(), (None, None))[0]
|
||||
|
||||
# Construct the final string
|
||||
names.append((entity.id_name, f"{entity_type}:{doc_semantic_id}"))
|
||||
|
||||
return names
|
||||
|
||||
|
||||
def get_entities_by_document_ids(
|
||||
db_session: Session, document_ids: list[str], kg_stage: KGStage
|
||||
) -> List[str]:
|
||||
"""Get all entity id_names that belong to the specified document ids.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
document_ids: List of document ids to filter by
|
||||
|
||||
Returns:
|
||||
List of entity id_names belonging to the specified document ids
|
||||
"""
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
stmt = select(KGEntityExtractionStaging.id_name).where(
|
||||
func.lower(KGEntityExtractionStaging.document_id).in_(document_ids)
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
stmt = select(KGEntity.id_name).where(
|
||||
func.lower(KGEntity.document_id).in_(document_ids)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage.value}")
|
||||
result = db_session.execute(stmt).scalars().all()
|
||||
return list(result)
|
||||
|
||||
|
||||
def get_document_id_for_entity(
|
||||
db_session: Session, entity: str, kg_stage: KGStage = KGStage.NORMALIZED
|
||||
) -> str | None:
|
||||
"""Get the document ID associated with an entity.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity: The entity id_name to look up
|
||||
kg_stage: The knowledge graph stage to search in (defaults to NORMALIZED)
|
||||
|
||||
Returns:
|
||||
The document ID if found, None otherwise
|
||||
"""
|
||||
|
||||
entity = entity.replace(": ", ":")
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
_KGEntityObject: Type[KGEntity | KGEntityExtractionStaging] = (
|
||||
KGEntityExtractionStaging
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
_KGEntityObject = KGEntity
|
||||
else:
|
||||
raise ValueError(f"Invalid KGStage: {kg_stage}")
|
||||
|
||||
stmt = select(_KGEntityObject.document_id).where(
|
||||
func.lower(_KGEntityObject.id_name) == func.lower(entity)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalars().first()
|
||||
return result
|
||||
|
||||
|
||||
def delete_from_kg_entities_extraction_staging__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete entities from the extraction staging table."""
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.document_id.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def delete_from_kg_entities__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete entities from the normalized table."""
|
||||
db_session.query(KGEntity).filter(KGEntity.document_id.in_(document_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
257
backend/onyx/db/entity_type.py
Normal file
257
backend/onyx/db/entity_type.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import KGEntityType
|
||||
from onyx.kg.kg_default_entity_definitions import KGDefaultAccountEmployeeDefinitions
|
||||
from onyx.kg.kg_default_entity_definitions import (
|
||||
KGDefaultPrimaryGroundedEntityDefinitions,
|
||||
)
|
||||
from onyx.kg.models import KGGroundingType
|
||||
|
||||
|
||||
def get_determined_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null entity_values.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have entity_values defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.entity_values.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_grounded_entity_types(db_session: Session) -> List[KGEntityType]:
|
||||
"""Get all entity types that have grounding = GROUNDED.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounding = GROUNDED
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_with_grounded_source_name(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have non-null grounded_source_name.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have grounded_source_name defined
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name.isnot(None))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_type_by_grounded_source_name(
|
||||
db_session: Session, grounded_source_name: KGGroundingType
|
||||
) -> KGEntityType | None:
|
||||
"""Get an entity type by its grounded_source_name and return it as a dictionary.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
grounded_source_name: The grounded_source_name of the entity to retrieve
|
||||
|
||||
Returns:
|
||||
Dictionary containing the entity's data with column names as keys,
|
||||
or None if the entity is not found
|
||||
"""
|
||||
entity_type = (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name == grounded_source_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if entity_type is None:
|
||||
return None
|
||||
|
||||
return entity_type
|
||||
|
||||
|
||||
def get_entity_types(
|
||||
db_session: Session,
|
||||
active: bool | None = True,
|
||||
) -> list[KGEntityType]:
|
||||
# Query the database for all distinct entity types
|
||||
|
||||
if active is None:
|
||||
return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all()
|
||||
|
||||
else:
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.active == active)
|
||||
.order_by(KGEntityType.id_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def populate_default_primary_grounded_entity_type_information(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Populate the entity type information for the KG.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
"""
|
||||
|
||||
# get kg config information
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
|
||||
# Get all existing entity types
|
||||
existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}
|
||||
|
||||
# Create an instance of the default definitions
|
||||
default_definitions = KGDefaultPrimaryGroundedEntityDefinitions()
|
||||
|
||||
# Iterate over all attributes in the default definitions
|
||||
for id_name, definition in default_definitions.model_dump().items():
|
||||
# Skip if this entity type already exists
|
||||
if id_name in existing_entity_types:
|
||||
continue
|
||||
|
||||
# Create new entity type
|
||||
|
||||
description = definition["description"].replace(
|
||||
"---vendor_name---", kg_config_settings.KG_VENDOR
|
||||
)
|
||||
|
||||
new_entity_type = KGEntityType(
|
||||
id_name=id_name,
|
||||
description=description,
|
||||
grounding=definition["grounding"],
|
||||
grounded_source_name=definition["grounded_source_name"],
|
||||
active=False,
|
||||
)
|
||||
|
||||
# Add to session
|
||||
db_session.add(new_entity_type)
|
||||
|
||||
# Commit changes
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def populate_default_employee_account_information(db_session: Session) -> None:
|
||||
"""Populate the entity type information for the KG.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
"""
|
||||
|
||||
# get kg config information
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
|
||||
# Get all existing entity types
|
||||
existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()}
|
||||
|
||||
# Create an instance of the default definitions
|
||||
default_definitions = KGDefaultAccountEmployeeDefinitions()
|
||||
|
||||
# Iterate over all attributes in the default definitions
|
||||
for id_name, definition in default_definitions.model_dump().items():
|
||||
# Skip if this entity type already exists
|
||||
if id_name in existing_entity_types:
|
||||
continue
|
||||
|
||||
# Create new entity type
|
||||
description = definition["description"].replace(
|
||||
"---vendor_name---", kg_config_settings.KG_VENDOR
|
||||
)
|
||||
new_entity_type = KGEntityType(
|
||||
id_name=id_name,
|
||||
description=description,
|
||||
grounding=definition["grounding"],
|
||||
grounded_source_name=definition["grounded_source_name"],
|
||||
active=definition["active"],
|
||||
)
|
||||
|
||||
# Add to session
|
||||
db_session.add(new_entity_type)
|
||||
|
||||
# Commit changes
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def get_grounded_entity_types_with_null_grounded_source(
|
||||
db_session: Session,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have null grounded_source_name and grounding = GROUNDED.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have null grounded_source_name and grounding = GROUNDED
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.grounded_source_name.is_(None))
|
||||
.filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_entity_types_by_grounding(
|
||||
db_session: Session,
|
||||
grounding: KGGroundingType,
|
||||
) -> List[KGEntityType]:
|
||||
"""Get all entity types that have a specific grounding.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
grounding: The grounding type to filter by
|
||||
|
||||
Returns:
|
||||
List of KGEntityType objects that have the specified grounding
|
||||
"""
|
||||
return (
|
||||
db_session.query(KGEntityType).filter(KGEntityType.grounding == grounding).all()
|
||||
)
|
||||
|
||||
|
||||
def get_grounded_source_name(db_session: Session, entity_type: str) -> str | None:
|
||||
"""
|
||||
Get the grounded source name for an entity type.
|
||||
"""
|
||||
|
||||
result = (
|
||||
db_session.query(KGEntityType)
|
||||
.filter(KGEntityType.id_name == entity_type)
|
||||
.first()
|
||||
)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result.grounded_source_name
|
||||
34
backend/onyx/db/kg_config.py
Normal file
34
backend/onyx/db/kg_config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import KGConfig
|
||||
from onyx.kg.models import KGConfigSettings
|
||||
from onyx.kg.models import KGConfigVars
|
||||
|
||||
|
||||
def get_kg_enablement(db_session: Session) -> bool:
|
||||
check = (
|
||||
db_session.query(KGConfig.kg_variable_values)
|
||||
.filter(
|
||||
KGConfig.kg_variable_name == "KG_ENABLED"
|
||||
and KGConfig.kg_variable_values == ["true"]
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return check is not None
|
||||
|
||||
|
||||
def get_kg_config_settings(db_session: Session) -> KGConfigSettings:
|
||||
results = db_session.query(KGConfig).all()
|
||||
|
||||
kg_config_settings = KGConfigSettings()
|
||||
for result in results:
|
||||
if result.kg_variable_name == "KG_ENABLED":
|
||||
kg_config_settings.KG_ENABLED = result.kg_variable_values[0] == "true"
|
||||
elif result.kg_variable_name == KGConfigVars.KG_VENDOR:
|
||||
kg_config_settings.KG_VENDOR = result.kg_variable_values[0]
|
||||
elif result.kg_variable_name == KGConfigVars.KG_VENDOR_DOMAINS:
|
||||
kg_config_settings.KG_VENDOR_DOMAINS = result.kg_variable_values
|
||||
elif result.kg_variable_name == KGConfigVars.KG_IGNORE_EMAIL_DOMAINS:
|
||||
kg_config_settings.KG_IGNORE_EMAIL_DOMAINS = result.kg_variable_values
|
||||
|
||||
return kg_config_settings
|
||||
131
backend/onyx/db/kg_temp_view.py
Normal file
131
backend/onyx/db/kg_temp_view.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
# First, create the view definition
|
||||
def create_views(
|
||||
db_session: Session,
|
||||
user_email: str,
|
||||
allowed_docs_view_name: str = "allowed_docs",
|
||||
kg_relationships_view_name: str = "kg_relationships_with_access",
|
||||
) -> None:
|
||||
# Create ALLOWED_DOCS view
|
||||
allowed_docs_view = text(
|
||||
f"""
|
||||
CREATE OR REPLACE VIEW {allowed_docs_view_name} AS
|
||||
WITH public_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document_by_connector_credential_pair d
|
||||
JOIN connector_credential_pair ccp ON
|
||||
d.connector_id = ccp.connector_id AND
|
||||
d.credential_id = ccp.credential_id
|
||||
WHERE ccp.status != 'DELETING'
|
||||
AND ccp.access_type = 'PUBLIC'
|
||||
),
|
||||
user_owned_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document_by_connector_credential_pair d
|
||||
JOIN credential c ON d.credential_id = c.id
|
||||
JOIN connector_credential_pair ccp ON
|
||||
d.connector_id = ccp.connector_id AND
|
||||
d.credential_id = ccp.credential_id
|
||||
JOIN "user" u ON c.user_id = u.id
|
||||
WHERE ccp.status != 'DELETING'
|
||||
AND ccp.access_type != 'SYNC'
|
||||
AND u.email = :user_email
|
||||
),
|
||||
external_user_docs AS (
|
||||
SELECT id as allowed_doc_id
|
||||
FROM document
|
||||
WHERE :user_email = ANY(external_user_emails)
|
||||
),
|
||||
external_group_docs AS (
|
||||
SELECT d.id as allowed_doc_id
|
||||
FROM document d
|
||||
JOIN user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids)
|
||||
JOIN "user" u ON ueg.user_id = u.id
|
||||
WHERE u.email = :user_email
|
||||
)
|
||||
SELECT DISTINCT allowed_doc_id FROM (
|
||||
SELECT allowed_doc_id FROM public_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM user_owned_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM external_user_docs
|
||||
UNION
|
||||
SELECT allowed_doc_id FROM external_group_docs
|
||||
) combined_docs
|
||||
"""
|
||||
).bindparams(user_email=user_email)
|
||||
|
||||
# Create the main view that uses ALLOWED_DOCS
|
||||
kg_relationships_view = text(
|
||||
f"""
|
||||
CREATE OR REPLACE VIEW {kg_relationships_view_name} AS
|
||||
SELECT kgr.id_name as relationship,
|
||||
kgr.source_node as source_entity,
|
||||
kgr.target_node as target_entity,
|
||||
kgr.source_node_type as source_entity_type,
|
||||
kgr.target_node_type as target_entity_type,
|
||||
kgr.type as relationship_description,
|
||||
kgr.relationship_type_id_name as relationship_type,
|
||||
kgr.source_document as source_document,
|
||||
d.doc_updated_at as source_date,
|
||||
se.attributes as source_entity_attributes,
|
||||
te.attributes as target_entity_attributes
|
||||
FROM kg_relationship kgr
|
||||
INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document
|
||||
JOIN document d on d.id = kgr.source_document
|
||||
JOIN kg_entity se on se.id_name = kgr.source_node
|
||||
JOIN kg_entity te on te.id_name = kgr.target_node
|
||||
"""
|
||||
)
|
||||
|
||||
# Execute the views using the session
|
||||
db_session.execute(allowed_docs_view)
|
||||
db_session.execute(kg_relationships_view)
|
||||
|
||||
grant_kg_relationships = text(
|
||||
f"GRANT SELECT ON {kg_relationships_view_name} TO {DB_READONLY_USER}"
|
||||
)
|
||||
db_session.execute(grant_kg_relationships)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def drop_views(
|
||||
db_session: Session,
|
||||
allowed_docs_view_name: str = "allowed_docs",
|
||||
kg_relationships_view_name: str = "kg_relationships_with_access",
|
||||
) -> None:
|
||||
"""
|
||||
Drops the temporary views created by create_views.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
allowed_docs_view_name: Name of the allowed_docs view
|
||||
kg_relationships_view_name: Name of the kg_relationships view
|
||||
"""
|
||||
# First revoke access from the readonly user
|
||||
revoke_kg_relationships = text(
|
||||
f"REVOKE SELECT ON {kg_relationships_view_name} FROM {DB_READONLY_USER}"
|
||||
)
|
||||
|
||||
db_session.execute(revoke_kg_relationships)
|
||||
|
||||
# Drop the views in reverse order of creation to handle dependencies
|
||||
drop_kg_relationships = text(f"DROP VIEW IF EXISTS {kg_relationships_view_name}")
|
||||
drop_allowed_docs = text(f"DROP VIEW IF EXISTS {allowed_docs_view_name}")
|
||||
|
||||
db_session.execute(drop_kg_relationships)
|
||||
db_session.execute(drop_allowed_docs)
|
||||
db_session.commit()
|
||||
return None
|
||||
@@ -39,6 +39,7 @@ from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
@@ -69,6 +70,7 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
@@ -586,6 +588,16 @@ class Document(Base):
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# tables for the knowledge graph data
|
||||
kg_stage: Mapped[KGStage] = mapped_column(
|
||||
Enum(KGStage, native_enum=False),
|
||||
comment="Status of knowledge graph extraction for this document",
|
||||
)
|
||||
|
||||
kg_processing_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
@@ -604,12 +616,589 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class KGConfig(Base):
|
||||
__tablename__ = "kg_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
|
||||
kg_variable_name: Mapped[str] = mapped_column(NullFilteredString, nullable=False)
|
||||
kg_variable_values: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
__tablename__ = "kg_entity_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
grounding: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
attributes: Mapped[str] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=True,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Filtering based on document attribute",
|
||||
)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deep_extraction: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
grounded_source_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, nullable=False, index=False
|
||||
)
|
||||
|
||||
entity_values: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True, default=None
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this entity type",
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipType(Base):
|
||||
__tablename__ = "kg_relationship_type"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type",
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipTypeExtractionStaging(Base):
|
||||
__tablename__ = "kg_relationship_type_extraction_staging"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
source_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
definition: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this relationship type represents a definition",
|
||||
)
|
||||
|
||||
clustering: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Clustering information for this relationship type",
|
||||
)
|
||||
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to EntityType
|
||||
source_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[source_entity_type_id_name],
|
||||
backref="source_relationship_type_staging",
|
||||
)
|
||||
target_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType",
|
||||
foreign_keys=[target_entity_type_id_name],
|
||||
backref="target_relationship_type_staging",
|
||||
)
|
||||
|
||||
|
||||
class KGEntity(Base):
|
||||
__tablename__ = "kg_entity"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
attributes: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Attributes for this entity",
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity")
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGEntityExtractionStaging(Base):
|
||||
__tablename__ = "kg_entity_extraction_staging"
|
||||
|
||||
# Primary identifier
|
||||
id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, index=True
|
||||
)
|
||||
|
||||
# Basic entity information
|
||||
name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
attributes: Mapped[dict] = mapped_column(
|
||||
postgresql.JSONB,
|
||||
nullable=False,
|
||||
default=dict,
|
||||
server_default="{}",
|
||||
comment="Attributes for this entity",
|
||||
)
|
||||
|
||||
document_id: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, nullable=True, index=True
|
||||
)
|
||||
|
||||
alternative_names: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Reference to KGEntityType
|
||||
entity_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship to KGEntityType
|
||||
entity_type: Mapped["KGEntityType"] = relationship(
|
||||
"KGEntityType", backref="entity_staging"
|
||||
)
|
||||
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
keywords: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Access control
|
||||
acl: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Boosts - using JSON for flexibility
|
||||
boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict)
|
||||
|
||||
event_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="Time of the event being processed",
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Fixed column names in indexes
|
||||
Index("ix_entity_type_acl", entity_type_id_name, acl),
|
||||
Index("ix_entity_name_search", name, entity_type_id_name),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationship(Base):
|
||||
__tablename__ = "kg_relationship"
|
||||
|
||||
# Primary identifier - now part of composite key
|
||||
id_name: Mapped[str] = mapped_column(NullFilteredString, index=True)
|
||||
|
||||
source_document: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Source and target nodes (foreign keys to Entity table)
|
||||
source_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
target_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True
|
||||
)
|
||||
|
||||
source_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship type
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
# Add new relationship type reference
|
||||
relationship_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_relationship_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Add the SQLAlchemy relationship property
|
||||
relationship_type: Mapped["KGRelationshipType"] = relationship(
|
||||
"KGRelationshipType", backref="relationship"
|
||||
)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to Entity table
|
||||
source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node])
|
||||
target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node])
|
||||
document: Mapped["Document"] = relationship(
|
||||
"Document", foreign_keys=[source_document]
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Composite primary key
|
||||
PrimaryKeyConstraint("id_name", "source_document"),
|
||||
# Index for querying relationships by type
|
||||
Index("ix_kg_relationship_type", type),
|
||||
# Composite index for source/target queries
|
||||
Index("ix_kg_relationship_nodes", source_node, target_node),
|
||||
# Ensure unique relationships between nodes of a specific type
|
||||
UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KGRelationshipExtractionStaging(Base):
|
||||
__tablename__ = "kg_relationship_extraction_staging"
|
||||
|
||||
# Primary identifier - now part of composite key
|
||||
id_name: Mapped[str] = mapped_column(NullFilteredString, index=True)
|
||||
|
||||
source_document: Mapped[str | None] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=True, index=True
|
||||
)
|
||||
|
||||
# Source and target nodes (foreign keys to Entity table)
|
||||
source_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
source_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
target_node_type: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_entity_type.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationship type
|
||||
type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True)
|
||||
|
||||
# Add new relationship type reference
|
||||
relationship_type_id_name: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
ForeignKey("kg_relationship_type_extraction_staging.id_name"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Add the SQLAlchemy relationship property
|
||||
relationship_type: Mapped["KGRelationshipTypeExtractionStaging"] = relationship(
|
||||
"KGRelationshipTypeExtractionStaging", backref="relationship_staging"
|
||||
)
|
||||
|
||||
occurrences: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Relationships to Entity table
|
||||
source: Mapped["KGEntityExtractionStaging"] = relationship(
|
||||
"KGEntityExtractionStaging", foreign_keys=[source_node]
|
||||
)
|
||||
target: Mapped["KGEntityExtractionStaging"] = relationship(
|
||||
"KGEntityExtractionStaging", foreign_keys=[target_node]
|
||||
)
|
||||
document: Mapped["Document"] = relationship(
|
||||
"Document", foreign_keys=[source_document]
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Composite primary key
|
||||
PrimaryKeyConstraint("id_name", "source_document"),
|
||||
# Index for querying relationships by type
|
||||
Index("ix_kg_relationship_type", type),
|
||||
# Composite index for source/target queries
|
||||
Index("ix_kg_relationship_nodes", source_node, target_node),
|
||||
# Ensure unique relationships between nodes of a specific type
|
||||
UniqueConstraint(
|
||||
"source_node",
|
||||
"target_node",
|
||||
"type",
|
||||
name="uq_kg_relationship_source_target_type",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KGTerm(Base):
|
||||
__tablename__ = "kg_term"
|
||||
|
||||
# Make id_term the primary key
|
||||
id_term: Mapped[str] = mapped_column(
|
||||
NullFilteredString, primary_key=True, nullable=False, index=True
|
||||
)
|
||||
|
||||
# List of entity types this term applies to
|
||||
entity_types: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=False, default=list
|
||||
)
|
||||
|
||||
# Tracking fields
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Index for searching terms with specific entity types
|
||||
Index("ix_search_term_entities", entity_types),
|
||||
# Index for term lookups
|
||||
Index("ix_search_term_term", id_term),
|
||||
)
|
||||
|
||||
|
||||
class ChunkStats(Base):
|
||||
__tablename__ = "chunk_stats"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
# (as is passed around in Onyx)x
|
||||
id: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
@@ -691,6 +1280,14 @@ class Connector(Base):
|
||||
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
|
||||
kg_processing_enabled: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
nullable=False,
|
||||
default=False,
|
||||
comment="Whether this connector should extract knowledge graph entities",
|
||||
)
|
||||
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
|
||||
523
backend/onyx/db/relationships.py
Normal file
523
backend/onyx/db/relationships.py
Normal file
@@ -0,0 +1,523 @@
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.db.models import KGStage
|
||||
from onyx.kg.utils.formatting_utils import format_entity
|
||||
from onyx.kg.utils.formatting_utils import format_relationship
|
||||
from onyx.kg.utils.formatting_utils import generate_relationship_type
|
||||
|
||||
|
||||
def add_relationship(
|
||||
db_session: Session,
|
||||
kg_stage: KGStage,
|
||||
relationship_id_name: str,
|
||||
source_document_id: str,
|
||||
occurrences: int | None = None,
|
||||
) -> Union["KGRelationship", "KGRelationshipExtractionStaging"]:
|
||||
"""
|
||||
Add a relationship between two entities to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
relationship_type: Type of relationship
|
||||
source_document_id: ID of the source document
|
||||
occurrences: Optional count of similar relationships clustered together
|
||||
|
||||
Returns:
|
||||
The created KGRelationship object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If the relationship already exists or entities don't exist
|
||||
"""
|
||||
# Generate a unique ID for the relationship
|
||||
|
||||
(
|
||||
source_entity_id_name,
|
||||
relationship_string,
|
||||
target_entity_id_name,
|
||||
) = relationship_id_name.split("__")
|
||||
|
||||
source_entity_id_name = format_entity(source_entity_id_name)
|
||||
source_entity_type = source_entity_id_name.split(":")[0]
|
||||
target_entity_id_name = format_entity(target_entity_id_name)
|
||||
target_entity_type = target_entity_id_name.split(":")[0]
|
||||
relationship_id_name = format_relationship(relationship_id_name)
|
||||
relationship_type = generate_relationship_type(relationship_id_name)
|
||||
|
||||
relationship_data = {
|
||||
"id_name": relationship_id_name,
|
||||
"source_node": source_entity_id_name,
|
||||
"target_node": target_entity_id_name,
|
||||
"source_node_type": source_entity_type,
|
||||
"target_node_type": target_entity_type,
|
||||
"type": relationship_string.lower(),
|
||||
"relationship_type_id_name": relationship_type,
|
||||
"source_document": source_document_id,
|
||||
"occurrences": occurrences or 1,
|
||||
}
|
||||
|
||||
relationship: KGRelationship | KGRelationshipExtractionStaging
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
relationship = KGRelationshipExtractionStaging(**relationship_data)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
relationship = KGRelationship(**relationship_data)
|
||||
else:
|
||||
raise ValueError(f"Invalid kg_stage: {kg_stage}")
|
||||
|
||||
# Use on_conflict_do_update to handle conflicts
|
||||
stmt = (
|
||||
postgresql.insert(type(relationship))
|
||||
.values(**relationship_data)
|
||||
.on_conflict_do_update(
|
||||
constraint=(
|
||||
"kg_relationship_pkey"
|
||||
if kg_stage == KGStage.NORMALIZED
|
||||
else "kg_relationship_extraction_staging_pkey"
|
||||
),
|
||||
set_={
|
||||
"occurrences": int(str(relationship_data["occurrences"] or 0))
|
||||
+ (occurrences or 1),
|
||||
"time_updated": func.now(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
# Fetch the updated/inserted record
|
||||
result: Union[KGRelationship, KGRelationshipExtractionStaging, None] = None
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
result = (
|
||||
db_session.query(KGRelationshipExtractionStaging)
|
||||
.filter_by(id_name=relationship_id_name, source_document=source_document_id)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
result = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter_by(id_name=relationship_id_name, source_document=source_document_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise ValueError(
|
||||
f"Failed to create or update relationship with id_name: {relationship_id_name}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_or_increment_relationship(
|
||||
db_session: Session,
|
||||
kg_stage: KGStage,
|
||||
relationship_id_name: str,
|
||||
source_document_id: str,
|
||||
new_occurrences: int = 1,
|
||||
) -> KGRelationship | KGRelationshipExtractionStaging:
|
||||
"""
|
||||
Add a relationship between two entities to the database if it doesn't exist,
|
||||
or increment its occurrences by 1 if it already exists.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
relationship_id_name: The ID name of the relationship in format "source__relationship__target"
|
||||
source_document_id: ID of the source document
|
||||
Returns:
|
||||
The created or updated KGRelationship object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If there's an error with the database operation
|
||||
"""
|
||||
# Format the relationship_id_name
|
||||
relationship_id_name = format_relationship(relationship_id_name)
|
||||
|
||||
_KGTable: type[KGRelationship] | type[KGRelationshipExtractionStaging]
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
_KGTable = KGRelationshipExtractionStaging
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
_KGTable = KGRelationship
|
||||
else:
|
||||
raise ValueError(f"Invalid kg_stage: {kg_stage}")
|
||||
|
||||
# Check if the relationship already exists
|
||||
existing_relationship = (
|
||||
db_session.query(_KGTable)
|
||||
.filter(_KGTable.id_name == relationship_id_name)
|
||||
.filter(_KGTable.source_document == source_document_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_relationship:
|
||||
# If it exists, increment the occurrences
|
||||
existing_relationship = cast(
|
||||
KGRelationship | KGRelationshipExtractionStaging, existing_relationship
|
||||
)
|
||||
existing_relationship.occurrences = (
|
||||
existing_relationship.occurrences or 0
|
||||
) + new_occurrences
|
||||
db_session.flush()
|
||||
return existing_relationship
|
||||
else:
|
||||
# If it doesn't exist, add it with occurrences=1
|
||||
db_session.flush()
|
||||
return add_relationship(
|
||||
db_session,
|
||||
KGStage(kg_stage),
|
||||
relationship_id_name,
|
||||
source_document_id,
|
||||
occurrences=new_occurrences,
|
||||
)
|
||||
|
||||
|
||||
def add_relationship_type(
|
||||
db_session: Session,
|
||||
kg_stage: KGStage,
|
||||
source_entity_type: str,
|
||||
relationship_type: str,
|
||||
target_entity_type: str,
|
||||
definition: bool = False,
|
||||
extraction_count: int = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Add a new relationship type to the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
source_entity_type: Type of the source entity
|
||||
relationship_type: Type of relationship
|
||||
target_entity_type: Type of the target entity
|
||||
definition: Whether this relationship type represents a definition (default False)
|
||||
|
||||
Returns:
|
||||
The created KGRelationshipType object
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.IntegrityError: If the relationship type already exists
|
||||
"""
|
||||
|
||||
id_name = f"{source_entity_type.upper()}__{relationship_type}__{target_entity_type.upper()}"
|
||||
# Create new relationship type
|
||||
|
||||
relationship_data = {
|
||||
"id_name": id_name,
|
||||
"name": relationship_type,
|
||||
"source_entity_type_id_name": source_entity_type.upper(),
|
||||
"target_entity_type_id_name": target_entity_type.upper(),
|
||||
"definition": definition,
|
||||
"occurrences": extraction_count,
|
||||
"type": relationship_type, # Using the relationship_type as the type
|
||||
"active": True, # Setting as active by default
|
||||
}
|
||||
|
||||
rel_type: KGRelationshipType | KGRelationshipTypeExtractionStaging
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
rel_type = KGRelationshipTypeExtractionStaging(**relationship_data)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
rel_type = KGRelationshipType(**relationship_data)
|
||||
else:
|
||||
raise ValueError(f"Invalid kg_stage: {kg_stage}")
|
||||
|
||||
# Use on_conflict_do_update to handle conflicts
|
||||
stmt = (
|
||||
postgresql.insert(type(rel_type))
|
||||
.values(**relationship_data)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["id_name"],
|
||||
set_={
|
||||
"name": relationship_data["name"],
|
||||
"source_entity_type_id_name": relationship_data[
|
||||
"source_entity_type_id_name"
|
||||
],
|
||||
"target_entity_type_id_name": relationship_data[
|
||||
"target_entity_type_id_name"
|
||||
],
|
||||
"definition": relationship_data["definition"],
|
||||
"occurrences": int(str(relationship_data["occurrences"] or 0))
|
||||
+ extraction_count,
|
||||
"type": relationship_data["type"],
|
||||
"active": relationship_data["active"],
|
||||
"time_updated": func.now(),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
db_session.flush() # Flush to get any DB errors early
|
||||
|
||||
return id_name
|
||||
|
||||
|
||||
def get_all_relationship_types(
|
||||
db_session: Session, kg_stage: str
|
||||
) -> list["KGRelationshipType"] | list["KGRelationshipTypeExtractionStaging"]:
|
||||
"""
|
||||
Retrieve all relationship types from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType or KGRelationshipTypeExtractionStaging objects
|
||||
"""
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
return db_session.query(KGRelationshipTypeExtractionStaging).all()
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
return db_session.query(KGRelationshipType).all()
|
||||
else:
|
||||
raise ValueError(f"Invalid kg_stage: {kg_stage}")
|
||||
|
||||
|
||||
def get_all_relationships(
|
||||
db_session: Session, kg_stage: KGStage
|
||||
) -> list["KGRelationship"] | list["KGRelationshipExtractionStaging"]:
|
||||
"""
|
||||
Retrieve all relationships from the database.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
|
||||
Returns:
|
||||
List of KGRelationship objects
|
||||
"""
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
return db_session.query(KGRelationshipExtractionStaging).all()
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
return db_session.query(KGRelationship).all()
|
||||
else:
|
||||
raise ValueError(f"Invalid kg_stage: {kg_stage}")
|
||||
|
||||
|
||||
def delete_relationships_by_id_names(
|
||||
db_session: Session, id_names: list[str], kg_stage: KGStage
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationships from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationships deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
|
||||
deleted_count = 0
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipExtractionStaging)
|
||||
.filter(KGRelationshipExtractionStaging.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationship)
|
||||
.filter(KGRelationship.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def delete_relationship_types_by_id_names(
|
||||
db_session: Session, id_names: list[str], kg_stage: KGStage
|
||||
) -> int:
|
||||
"""
|
||||
Delete relationship types from the database based on a list of id_names.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
id_names: List of relationship type id_names to delete
|
||||
|
||||
Returns:
|
||||
Number of relationship types deleted
|
||||
|
||||
Raises:
|
||||
sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion
|
||||
"""
|
||||
deleted_count = 0
|
||||
|
||||
if kg_stage == KGStage.EXTRACTED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipTypeExtractionStaging)
|
||||
.filter(KGRelationshipTypeExtractionStaging.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
elif kg_stage == KGStage.NORMALIZED:
|
||||
deleted_count = (
|
||||
db_session.query(KGRelationshipType)
|
||||
.filter(KGRelationshipType.id_name.in_(id_names))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
db_session.flush() # Flush to ensure deletion is processed
|
||||
return deleted_count
|
||||
|
||||
|
||||
def get_relationships_for_entity_type_pairs(
|
||||
db_session: Session, entity_type_pairs: list[tuple[str, str]]
|
||||
) -> list["KGRelationshipType"]:
|
||||
"""
|
||||
Get relationship types from the database based on a list of entity type pairs.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type)
|
||||
|
||||
Returns:
|
||||
List of KGRelationshipType objects where source and target types match the provided pairs
|
||||
"""
|
||||
|
||||
conditions = [
|
||||
(
|
||||
(KGRelationshipType.source_entity_type_id_name == source_type)
|
||||
& (KGRelationshipType.target_entity_type_id_name == target_type)
|
||||
)
|
||||
for source_type, target_type in entity_type_pairs
|
||||
]
|
||||
|
||||
return db_session.query(KGRelationshipType).filter(or_(*conditions)).all()
|
||||
|
||||
|
||||
def get_allowed_relationship_type_pairs(
|
||||
db_session: Session, entities: list[str]
|
||||
) -> list[str]:
|
||||
"""
|
||||
Get the allowed relationship pairs for the given entities.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy database session
|
||||
entities: List of entity type ID names to filter by
|
||||
|
||||
Returns:
|
||||
List of id_names from KGRelationshipType where both source and target entity types
|
||||
are in the provided entities list
|
||||
"""
|
||||
entity_types = list(set([entity.split(":")[0] for entity in entities]))
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(KGRelationshipType.source_entity_type_id_name.in_(entity_types))
|
||||
.filter(KGRelationshipType.target_entity_type_id_name.in_(entity_types))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]:
|
||||
"""Get all relationship ID names where the given entity is either the source or target node.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_id: ID of the entity to find relationships for
|
||||
|
||||
Returns:
|
||||
List of relationship ID names where the entity is either source or target
|
||||
"""
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationship.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationship.source_node == entity_id,
|
||||
KGRelationship.target_node == entity_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_relationship_types_of_entity_types(
|
||||
db_session: Session, entity_types_id: str
|
||||
) -> List[str]:
|
||||
"""Get all relationship ID names where the given entity is either the source or target node.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
entity_types_id: ID of the entity to find relationships for
|
||||
|
||||
Returns:
|
||||
List of relationship ID names where the entity is either source or target
|
||||
"""
|
||||
|
||||
if entity_types_id.endswith(":*"):
|
||||
entity_types_id = entity_types_id[:-2]
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in (
|
||||
db_session.query(KGRelationshipType.id_name)
|
||||
.filter(
|
||||
or_(
|
||||
KGRelationshipType.source_entity_type_id_name == entity_types_id,
|
||||
KGRelationshipType.target_entity_type_id_name == entity_types_id,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def delete_document_references_from_kg(db_session: Session, document_id: str) -> None:
|
||||
# Delete relationships from normalized stage
|
||||
db_session.query(KGRelationship).filter(
|
||||
KGRelationship.source_document == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete relationships from extraction staging
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.source_document == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Delete entities from normalized stage
|
||||
db_session.query(KGEntity).filter(KGEntity.document_id == document_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
# Delete entities from extraction staging
|
||||
db_session.query(KGEntityExtractionStaging).filter(
|
||||
KGEntityExtractionStaging.document_id == document_id
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def delete_from_kg_relationships_extraction_staging__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete relationships from the extraction staging table."""
|
||||
db_session.query(KGRelationshipExtractionStaging).filter(
|
||||
KGRelationshipExtractionStaging.source_document.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
|
||||
def delete_from_kg_relationships__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""Delete relationships from the normalized table."""
|
||||
db_session.query(KGRelationship).filter(
|
||||
KGRelationship.source_document.in_(document_ids)
|
||||
).delete(synchronize_session=False)
|
||||
@@ -91,6 +91,25 @@ schema {{ schema_name }} {
|
||||
indexing: attribute
|
||||
}
|
||||
|
||||
# Separate array fields for knowledge graph data
|
||||
field kg_entities type weightedset<string> {
|
||||
rank: filter
|
||||
indexing: summary | attribute
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
field kg_relationships type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
field kg_terms type weightedset<string> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
|
||||
# Needs to have a separate Attribute list for efficient filtering
|
||||
field metadata_list type array<string> {
|
||||
indexing: summary | attribute
|
||||
|
||||
@@ -166,18 +166,19 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
# build the list of fields to retrieve
|
||||
field_set_list = (
|
||||
None
|
||||
if not field_names
|
||||
else [f"{index_name}:{field_name}" for field_name in field_names]
|
||||
[f"{field_name}" for field_name in field_names] if field_names else []
|
||||
)
|
||||
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
|
||||
acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}"
|
||||
if (
|
||||
field_set_list
|
||||
and filters.access_control_list
|
||||
and acl_fieldset_entry not in field_set_list
|
||||
):
|
||||
field_set_list.append(acl_fieldset_entry)
|
||||
field_set = ",".join(field_set_list) if field_set_list else None
|
||||
if field_set_list:
|
||||
field_set = f"{index_name}:" + ",".join(field_set_list)
|
||||
else:
|
||||
field_set = None
|
||||
|
||||
# build filters
|
||||
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
|
||||
|
||||
@@ -18,6 +18,7 @@ from uuid import UUID
|
||||
import httpx # type: ignore
|
||||
import jinja2
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
@@ -30,6 +31,7 @@ from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
@@ -86,6 +88,17 @@ httpx_logger = logging.getLogger("httpx")
|
||||
httpx_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def update_kg_type_dict(
|
||||
dict_to_update: dict[str, dict], kg_type: str, value_set: set[str]
|
||||
) -> dict[str, dict]:
|
||||
if "fields" not in dict_to_update:
|
||||
dict_to_update["fields"] = {}
|
||||
dict_to_update["fields"][kg_type] = {
|
||||
"assign": {kg_type_object: 1 for kg_type_object in value_set}
|
||||
}
|
||||
return dict_to_update
|
||||
|
||||
|
||||
@dataclass
|
||||
class _VespaUpdateRequest:
|
||||
document_id: str
|
||||
@@ -93,6 +106,39 @@ class _VespaUpdateRequest:
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGVespaChunkUpdateRequest(BaseModel):
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
url: str
|
||||
update_request: dict[str, dict]
|
||||
|
||||
|
||||
class KGUChunkUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: set[str] | None = None
|
||||
relationships: set[str] | None = None
|
||||
terms: set[str] | None = None
|
||||
converted_attributes: set[str] | None = None
|
||||
attributes: dict[str, str | list[str]] | None = None
|
||||
|
||||
|
||||
class KGUDocumentUpdateRequest(BaseModel):
|
||||
"""
|
||||
Update KG fields for a document
|
||||
"""
|
||||
|
||||
document_id: str
|
||||
entities: set[str]
|
||||
relationships: set[str]
|
||||
terms: set[str]
|
||||
|
||||
|
||||
def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
zip_buffer = io.BytesIO()
|
||||
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
|
||||
@@ -501,6 +547,51 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
@classmethod
|
||||
def _apply_kg_chunk_updates_batched(
|
||||
cls,
|
||||
updates: list[KGVespaChunkUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
|
||||
def _kg_update_chunk(
|
||||
update: KGVespaChunkUpdateRequest, http_client: httpx.Client
|
||||
) -> httpx.Response:
|
||||
# logger.debug(
|
||||
# f"Updating KG with request to {update.url} with body {update.update_request}"
|
||||
# )
|
||||
return http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficient for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
httpx_client as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
executor.submit(
|
||||
_kg_update_chunk,
|
||||
update,
|
||||
http_client,
|
||||
): update.document_id
|
||||
for update in update_batch
|
||||
}
|
||||
for future in concurrent.futures.as_completed(future_to_document_id):
|
||||
res = future.result()
|
||||
try:
|
||||
res.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
@@ -584,6 +675,77 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def kg_chunk_updates(
|
||||
self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
|
||||
) -> None:
|
||||
|
||||
processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
|
||||
logger.debug(f"Updating {len(kg_update_requests)} documents in Vespa")
|
||||
|
||||
update_start = time.monotonic()
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
|
||||
for kg_update_request in kg_update_requests:
|
||||
kg_update_dict: dict[str, dict] = {"fields": {}}
|
||||
|
||||
implied_entities = set()
|
||||
if kg_update_request.relationships is not None:
|
||||
for kg_relationship in kg_update_request.relationships:
|
||||
kg_relationship_split = kg_relationship.split("__")
|
||||
if len(kg_relationship_split) == 3:
|
||||
implied_entities.add(kg_relationship_split[0])
|
||||
implied_entities.add(kg_relationship_split[2])
|
||||
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_relationships", kg_update_request.relationships
|
||||
)
|
||||
|
||||
if kg_update_request.entities is not None or implied_entities:
|
||||
if kg_update_request.entities is None:
|
||||
kg_entities = implied_entities
|
||||
else:
|
||||
kg_entities = set(kg_update_request.entities)
|
||||
kg_entities.update(implied_entities)
|
||||
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_entities", kg_entities
|
||||
)
|
||||
|
||||
if kg_update_request.terms is not None:
|
||||
kg_update_dict = update_kg_type_dict(
|
||||
kg_update_dict, "kg_terms", kg_update_request.terms
|
||||
)
|
||||
|
||||
if not kg_update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
continue
|
||||
|
||||
doc_chunk_id = get_uuid_from_chunk_info(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
tenant_id=tenant_id,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
processed_updates_requests.append(
|
||||
KGVespaChunkUpdateRequest(
|
||||
document_id=kg_update_request.document_id,
|
||||
chunk_id=kg_update_request.chunk_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=kg_update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_kg_chunk_updates_batched(
|
||||
processed_updates_requests, httpx_client
|
||||
)
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
|
||||
80
backend/onyx/document_index/vespa/kg_interactions.py
Normal file
80
backend/onyx/document_index/vespa/kg_interactions.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from backend.onyx.chat.process_message import get_inference_chunks
|
||||
# from backend.onyx.document_index.vespa.index import VespaIndex
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class KGChunkInfo(BaseModel):
|
||||
kg_relationships: dict[str, int]
|
||||
kg_entities: dict[str, int]
|
||||
kg_terms: dict[str, int]
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def get_document_kg_info(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
filters: IndexFilters | None = None,
|
||||
) -> dict | None:
|
||||
"""
|
||||
Retrieve the kg_info attribute from a Vespa document by its document_id.
|
||||
Args:
|
||||
document_id: The unique identifier of the document.
|
||||
index_name: The name of the Vespa index to query.
|
||||
filters: Optional access control filters to apply.
|
||||
Returns:
|
||||
The kg_info dictionary if found, None otherwise.
|
||||
"""
|
||||
# Use the existing visit API infrastructure
|
||||
kg_doc_info: dict[int, KGChunkInfo] = {}
|
||||
|
||||
document_chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=filters or IndexFilters(access_control_list=None),
|
||||
field_names=["kg_relationships", "kg_entities", "kg_terms"],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
for chunk_id, document_chunk in enumerate(document_chunks):
|
||||
kg_chunk_info = KGChunkInfo(
|
||||
kg_relationships=document_chunk["fields"].get("kg_relationships", {}),
|
||||
kg_entities=document_chunk["fields"].get("kg_entities", {}),
|
||||
kg_terms=document_chunk["fields"].get("kg_terms", {}),
|
||||
)
|
||||
|
||||
kg_doc_info[chunk_id] = kg_chunk_info # TODO: check the chunk id is correct!
|
||||
|
||||
return kg_doc_info
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def update_kg_chunks_vespa_info(
|
||||
kg_update_requests: list[KGUChunkUpdateRequest],
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
""" """
|
||||
# Use the existing visit API infrastructure
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=False,
|
||||
multitenant=False,
|
||||
httpx_client=None,
|
||||
)
|
||||
|
||||
vespa_index.kg_chunk_updates(
|
||||
kg_update_requests=kg_update_requests, tenant_id=tenant_id
|
||||
)
|
||||
@@ -54,6 +54,47 @@ def build_vespa_filters(
|
||||
|
||||
return result
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
kg_relationships: list[str] | None,
|
||||
kg_terms: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_entities and not kg_relationships and not kg_terms:
|
||||
return ""
|
||||
|
||||
filter_parts = []
|
||||
|
||||
# Process each filter type using the same pattern
|
||||
for filter_type, values in [
|
||||
("kg_entities", kg_entities),
|
||||
("kg_relationships", kg_relationships),
|
||||
("kg_terms", kg_terms),
|
||||
]:
|
||||
if values:
|
||||
filter_parts.append(
|
||||
" and ".join(f'({filter_type} contains "{val}") ' for val in values)
|
||||
)
|
||||
|
||||
return f"({' and '.join(filter_parts)}) and "
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
) -> str:
|
||||
if not kg_sources:
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
untimed_doc_cutoff: timedelta = timedelta(days=92),
|
||||
@@ -106,6 +147,19 @@ def build_vespa_filters(
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# Knowledge Graph Filters
|
||||
filter_str += _build_kg_filter(
|
||||
kg_entities=filters.kg_entities,
|
||||
kg_relationships=filters.kg_relationships,
|
||||
kg_terms=filters.kg_terms,
|
||||
)
|
||||
|
||||
filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
|
||||
filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
filters.kg_chunk_id_zero_only or False
|
||||
)
|
||||
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
178
backend/onyx/kg/clustering/incremental_cluster_updates.py
Normal file
178
backend/onyx/kg/clustering/incremental_cluster_updates.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from typing import Set
|
||||
|
||||
from onyx.db.document import update_document_kg_info
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import add_entity
|
||||
from onyx.db.entities import delete_entities_by_id_names
|
||||
from onyx.db.entities import get_entities_by_grounding
|
||||
from onyx.db.relationships import add_relationship
|
||||
from onyx.db.relationships import add_relationship_type
|
||||
from onyx.db.relationships import delete_relationship_types_by_id_names
|
||||
from onyx.db.relationships import delete_relationships_by_id_names
|
||||
from onyx.db.relationships import get_all_relationship_types
|
||||
from onyx.db.relationships import get_all_relationships
|
||||
from onyx.kg.models import KGGroundingType
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def kg_incremental_cluster_updates(
|
||||
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
|
||||
) -> None:
|
||||
"""
|
||||
Here we will cluster the extractions based on their cluster frameworks.
|
||||
Initially, this will only focus on grounded entities with pre-determined
|
||||
relationships, so clustering is actually not yet required.
|
||||
The primary purpose of this function is to populate the actual KG tables
|
||||
from the temp_extraction tables.
|
||||
|
||||
This will change with deep extraction, where grounded-sourceless entities
|
||||
can be extracted and then need to be clustered.
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg clustering for tenant {tenant_id}")
|
||||
|
||||
## Retrieval
|
||||
|
||||
source_documents_w_successful_transfers: Set[str] = set()
|
||||
source_documents_w_failed_transfers: Set[str] = set()
|
||||
|
||||
# get objects that are now in the Staging tables
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
relationship_types = get_all_relationship_types(
|
||||
db_session, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
|
||||
relationships = get_all_relationships(db_session, kg_stage=KGStage.EXTRACTED)
|
||||
|
||||
grounded_entities = get_entities_by_grounding(
|
||||
db_session, KGStage.EXTRACTED, KGGroundingType.GROUNDED
|
||||
)
|
||||
|
||||
## Clustering
|
||||
|
||||
# TODO: we will re-implement the cluster matching logic here
|
||||
|
||||
## Database operations
|
||||
|
||||
# create the clustered objects - entities
|
||||
|
||||
transferred_entities: list[str] = []
|
||||
for grounded_entity in grounded_entities:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
added_entity = add_entity(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
entity_type=grounded_entity.entity_type_id_name,
|
||||
name=grounded_entity.name,
|
||||
occurrences=grounded_entity.occurrences or 1,
|
||||
document_id=grounded_entity.document_id or None,
|
||||
attributes=grounded_entity.attributes or None,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if added_entity:
|
||||
transferred_entities.append(added_entity.id_name)
|
||||
|
||||
transferred_relationship_types: list[str] = []
|
||||
for relationship_type in relationship_types:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
added_relationship_type_id_name = add_relationship_type(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
source_entity_type=relationship_type.source_entity_type_id_name,
|
||||
relationship_type=relationship_type.type,
|
||||
target_entity_type=relationship_type.target_entity_type_id_name,
|
||||
extraction_count=relationship_type.occurrences or 1,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
transferred_relationship_types.append(added_relationship_type_id_name)
|
||||
|
||||
transferred_relationships: list[str] = []
|
||||
for relationship in relationships:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
added_relationship = add_relationship(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
relationship_id_name=relationship.id_name,
|
||||
source_document_id=relationship.source_document or "",
|
||||
occurrences=relationship.occurrences or 1,
|
||||
)
|
||||
|
||||
if relationship.source_document:
|
||||
source_documents_w_successful_transfers.add(
|
||||
relationship.source_document
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
transferred_relationships.append(added_relationship.id_name)
|
||||
|
||||
except Exception as e:
|
||||
if relationship.source_document:
|
||||
source_documents_w_failed_transfers.add(
|
||||
relationship.source_document
|
||||
)
|
||||
logger.error(
|
||||
f"Error transferring relationship {relationship.id_name}: {e}"
|
||||
)
|
||||
|
||||
# TODO: remove the /relationship types & entities that correspond to relationships
|
||||
# source documents that failed to transfer. I.e, do a proper rollback
|
||||
|
||||
# TODO: update Vespa info when clustering/changes are performed
|
||||
|
||||
# delete the added objects from the staging tables
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_relationships_by_id_names(
|
||||
db_session, transferred_relationships, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationships: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_relationship_types_by_id_names(
|
||||
db_session, transferred_relationship_types, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationship types: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_entities_by_id_names(
|
||||
db_session, transferred_entities, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entities: {e}")
|
||||
|
||||
# Update document kg info
|
||||
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# all_kg_extracted_documents_info = get_all_kg_extracted_documents_info(
|
||||
# db_session
|
||||
# )
|
||||
|
||||
for document_id in source_documents_w_successful_transfers:
|
||||
|
||||
# Update the document kg info
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_info(
|
||||
db_session,
|
||||
document_id=document_id,
|
||||
kg_stage=KGStage.NORMALIZED,
|
||||
)
|
||||
db_session.commit()
|
||||
638
backend/onyx/kg/clustering/initial_clustering.py
Normal file
638
backend/onyx/kg/clustering/initial_clustering.py
Normal file
@@ -0,0 +1,638 @@
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Set
|
||||
|
||||
import numpy as np
|
||||
from sklearn.cluster import SpectralClustering # type: ignore
|
||||
from thefuzz import fuzz # type: ignore
|
||||
|
||||
from onyx.db.document import update_document_kg_info
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import add_entity
|
||||
from onyx.db.entities import delete_entities_by_id_names
|
||||
from onyx.db.entities import get_entities_by_grounding
|
||||
from onyx.db.entity_type import get_determined_grounded_entity_types
|
||||
from onyx.db.relationships import add_relationship
|
||||
from onyx.db.relationships import add_relationship_type
|
||||
from onyx.db.relationships import delete_relationship_types_by_id_names
|
||||
from onyx.db.relationships import delete_relationships_by_id_names
|
||||
from onyx.db.relationships import get_all_relationship_types
|
||||
from onyx.db.relationships import get_all_relationships
|
||||
from onyx.kg.models import KGGroundingType
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.kg.utils.embeddings import encode_string_batch
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_ge_determined_entity_map() -> Dict[str, List[str]]:
|
||||
"""Create a mapping of entity type ID names to their grounding determination instructions.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping entity type ID names to their list of grounding determination instructions
|
||||
"""
|
||||
ge_determined_entity_map: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
determined_entities = get_determined_grounded_entity_types(db_session)
|
||||
|
||||
for entity_type in determined_entities:
|
||||
if entity_type.entity_values: # Extra safety check
|
||||
ge_determined_entity_map[entity_type.id_name] = (
|
||||
entity_type.entity_values
|
||||
)
|
||||
|
||||
return ge_determined_entity_map
|
||||
|
||||
|
||||
def _cluster_relationships(
|
||||
relationship_data: List[dict], n_clusters: int = 3, batch_size: int = 12
|
||||
) -> Dict[int, List[str]]:
|
||||
"""
|
||||
Cluster relationships using their embeddings.
|
||||
|
||||
Args:
|
||||
relationship_data: List of dicts with 'name' and 'cluster_count'
|
||||
n_clusters: Number of clusters to create
|
||||
batch_size: Size of batches for embedding requests
|
||||
|
||||
Returns:
|
||||
Dictionary mapping cluster IDs to lists of relationship names
|
||||
"""
|
||||
|
||||
# TODO: This is TEMP for the pre-defined relationships.
|
||||
# if len(relationship_data) < n_clusters:
|
||||
if len(relationship_data) < n_clusters:
|
||||
logger.warning(
|
||||
"Not enough relationships to cluster. Returning each relationship as its own cluster."
|
||||
)
|
||||
return {i: [rel["name"]] for (i, rel) in enumerate(relationship_data)}
|
||||
|
||||
train_data = []
|
||||
rel_names = []
|
||||
|
||||
# Process relationships in batches
|
||||
for i in range(0, len(relationship_data), batch_size):
|
||||
batch = relationship_data[i : i + batch_size]
|
||||
batch_names = [
|
||||
rel["name"].replace("_", " ") for rel in batch
|
||||
] # better for LLM to have spaces between words
|
||||
|
||||
# Get embeddings for the entire batch at once
|
||||
batch_embeddings = encode_string_batch(batch_names)
|
||||
|
||||
# Add embeddings and corresponding data
|
||||
for rel, embedding in zip(batch, batch_embeddings):
|
||||
count = int(rel["cluster_count"]) or 1
|
||||
# Add the relationship name 'count' times
|
||||
for _ in range(count):
|
||||
train_data.append(embedding)
|
||||
rel_names.append(rel["name"])
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(train_data)
|
||||
|
||||
# Perform clustering
|
||||
# clustering = KMeans(n_clusters=n_clusters, random_state=42)
|
||||
clustering = SpectralClustering(n_clusters=n_clusters, random_state=42)
|
||||
clusters = clustering.fit_predict(X)
|
||||
|
||||
# Group relationship names by cluster
|
||||
cluster_groups: Dict[int, List[str]] = defaultdict(list)
|
||||
for rel_name, cluster_id in zip(rel_names, clusters):
|
||||
if rel_name not in cluster_groups[cluster_id]:
|
||||
cluster_groups[cluster_id].append(rel_name)
|
||||
|
||||
return dict(cluster_groups)
|
||||
|
||||
|
||||
def _cluster_entities(
|
||||
entity_data: List[dict], n_clusters: int = 3, batch_size: int = 12
|
||||
) -> Dict[int, List[str]]:
|
||||
"""
|
||||
Cluster entities using their embeddings.
|
||||
|
||||
Args:
|
||||
entity_data: List of dicts with 'name' and 'cluster_count'
|
||||
n_clusters: Number of clusters to create
|
||||
batch_size: Size of batches for embedding requests
|
||||
|
||||
Returns:
|
||||
Dictionary mapping cluster IDs to lists of entity names
|
||||
"""
|
||||
|
||||
if len(entity_data) < n_clusters:
|
||||
logger.warning(
|
||||
"Not enough entities to cluster. Returning each entity as its own cluster."
|
||||
)
|
||||
return {
|
||||
i: [ent["name"] for ent in entity_data] for i in range(len(entity_data))
|
||||
}
|
||||
|
||||
train_data = []
|
||||
entity_names = []
|
||||
|
||||
# Process entities in batches
|
||||
for i in range(0, len(entity_data), batch_size):
|
||||
batch = entity_data[i : i + batch_size]
|
||||
batch_names = [
|
||||
ent["name"].replace("_", " ") for ent in batch
|
||||
] # use spaces between words for LLM
|
||||
|
||||
# Get embeddings for the entire batch at once
|
||||
batch_embeddings = encode_string_batch(batch_names)
|
||||
|
||||
# Add embeddings and corresponding data
|
||||
for ent, embedding in zip(batch, batch_embeddings):
|
||||
count = int(ent["cluster_count"]) or 1
|
||||
|
||||
# Add the entity name 'count' times
|
||||
for _ in range(count):
|
||||
entity_names.append(ent["name"])
|
||||
train_data.append(embedding)
|
||||
|
||||
# Convert to numpy arrays
|
||||
X = np.array(train_data)
|
||||
|
||||
# Perform clustering
|
||||
# clustering = KMeans(n_clusters=n_clusters, random_state=42)
|
||||
clustering = SpectralClustering(n_clusters=n_clusters, random_state=42)
|
||||
clusters = clustering.fit_predict(X)
|
||||
|
||||
# Group entity names by cluster
|
||||
cluster_groups: Dict[int, List[str]] = defaultdict(list)
|
||||
for ent_name, cluster_id in zip(entity_names, clusters):
|
||||
if ent_name not in cluster_groups[cluster_id]:
|
||||
cluster_groups[cluster_id].append(ent_name)
|
||||
|
||||
return dict(cluster_groups)
|
||||
|
||||
|
||||
def _create_relationship_type_mapping(
|
||||
full_clustering_results: Dict[str, Dict[str, Dict[int, Dict[str, Any]]]],
|
||||
relationship_mapping: Dict[str, Dict[str, List[dict]]],
|
||||
) -> tuple[Dict[str, str], Dict[str, int]]:
|
||||
"""
|
||||
Create a mapping between original relationship types and their clustered versions.
|
||||
|
||||
Args:
|
||||
full_clustering_results: Clustering results with cluster names
|
||||
relationship_mapping: Original relationship types organized by source/target
|
||||
|
||||
Returns:
|
||||
Dictionary mapping original relationship type ID to clustered relationship type ID
|
||||
"""
|
||||
relationship_type_replacements: Dict[str, str] = {}
|
||||
reverse_relationship_type_replacements_count: Dict[str, int] = defaultdict(int)
|
||||
|
||||
for source_type, target_dict in relationship_mapping.items():
|
||||
for target_type, rel_types in target_dict.items():
|
||||
# Get clusters for this source/target pair
|
||||
clusters = full_clustering_results.get(source_type, {}).get(target_type, {})
|
||||
|
||||
for cluster_id, cluster_info in clusters.items():
|
||||
cluster_name = cluster_info["cluster_name"]
|
||||
for rel_name in cluster_info["relationships"]:
|
||||
original_id = f"{source_type}__{rel_name.lower()}__{target_type}"
|
||||
clustered_id = (
|
||||
f"{source_type}__{cluster_name.lower()}__{target_type}"
|
||||
)
|
||||
relationship_type_replacements[original_id] = clustered_id
|
||||
reverse_relationship_type_replacements_count[clustered_id] += len(
|
||||
cluster_info["relationships"]
|
||||
)
|
||||
|
||||
return relationship_type_replacements, reverse_relationship_type_replacements_count
|
||||
|
||||
|
||||
def _create_entity_mapping(
|
||||
full_entity_clustering_results: Dict[str, Dict[int, Dict[str, Any]]],
|
||||
entity_mapping: Dict[str, List[dict]],
|
||||
) -> tuple[Dict[str, str], Dict[str, int]]:
|
||||
"""
|
||||
Create a mapping between original entities and their clustered versions.
|
||||
|
||||
Args:
|
||||
full_entity_clustering_results: Clustering results with cluster names
|
||||
entity_mapping: Original entities organized by entity type
|
||||
|
||||
Returns:
|
||||
Dictionary mapping original entity ID to clustered entity ID
|
||||
"""
|
||||
entity_replacements: Dict[str, str] = {}
|
||||
reverse_entity_replacements_count: Dict[str, int] = defaultdict(int)
|
||||
|
||||
for entity_type, clusters in full_entity_clustering_results.items():
|
||||
for cluster_id, cluster_info in clusters.items():
|
||||
cluster_name = cluster_info["cluster_name"]
|
||||
for entity_name in cluster_info["entities"]:
|
||||
# Skip wildcard entities
|
||||
if entity_name == "*":
|
||||
continue
|
||||
|
||||
original_id = f"{entity_type}:{entity_name}"
|
||||
clustered_id = f"{entity_type}:{cluster_name.title()}"
|
||||
entity_replacements[original_id] = clustered_id
|
||||
reverse_entity_replacements_count[clustered_id] += len(
|
||||
cluster_info["entities"]
|
||||
)
|
||||
return entity_replacements, reverse_entity_replacements_count
|
||||
|
||||
|
||||
def _create_relationship_mapping(
|
||||
relationship_type_replacements: Dict[str, str],
|
||||
reverse_relationship_type_replacements_count: Dict[str, int],
|
||||
entity_replacements: Dict[str, str],
|
||||
reverse_entity_replacements_count: Dict[str, int],
|
||||
relationships: List[
|
||||
Any
|
||||
], # This would be List[KGRelationship] but avoiding the import
|
||||
) -> tuple[Dict[str, str], Dict[str, int]]:
|
||||
"""
|
||||
Create a mapping between original relationships and their clustered versions,
|
||||
taking into account both clustered relationship types and clustered entities.
|
||||
|
||||
Args:
|
||||
relationship_type_replacements: Mapping of original to clustered relationship type IDs
|
||||
entity_replacements: Mapping of original to clustered entity IDs
|
||||
relationships: List of relationships from the database
|
||||
|
||||
Returns:
|
||||
Dictionary mapping original relationship ID to clustered relationship ID
|
||||
"""
|
||||
relationship_replacements: Dict[str, str] = {}
|
||||
reverse_relationship_replacements_count: Dict[str, int] = defaultdict(int)
|
||||
|
||||
for rel in relationships:
|
||||
# Skip if source or target is a wildcard
|
||||
|
||||
# Get the clustered entities (if they exist)
|
||||
source_node = entity_replacements.get(rel.source_node, rel.source_node)
|
||||
target_node = entity_replacements.get(rel.target_node, rel.target_node)
|
||||
|
||||
rel.source_document
|
||||
|
||||
# Create the relationship type ID
|
||||
source_type = rel.source_node.split(":")[0]
|
||||
target_type = rel.target_node.split(":")[0]
|
||||
rel_type_id = f"{source_type}__{rel.type.lower()}__{target_type}"
|
||||
|
||||
# Get the clustered relationship type (if it exists)
|
||||
clustered_rel_type_id = relationship_type_replacements.get(
|
||||
rel_type_id, rel_type_id
|
||||
)
|
||||
|
||||
# Extract the relationship name from the clustered type ID
|
||||
_, rel_name, _ = clustered_rel_type_id.split("__")
|
||||
|
||||
# Create the original and clustered relationship IDs
|
||||
original_id = f"{rel.source_node}__{rel.type.lower()}__{rel.target_node}"
|
||||
clustered_id = f"{source_node}__{rel_name}__{target_node}"
|
||||
|
||||
relationship_replacements[original_id] = clustered_id
|
||||
reverse_relationship_replacements_count[clustered_id] += rel.occurrences or 1
|
||||
|
||||
return relationship_replacements, reverse_relationship_replacements_count
|
||||
|
||||
|
||||
def _match_ungrounded_ge_entities(
|
||||
ungrounded_ge_entities: Dict[str, List[str]],
|
||||
grounded_ge_entities: Dict[str, List[str]],
|
||||
fuzzy_match_threshold: int = 80,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""
|
||||
Create a mapping for ungrounded entities by matching them to grounded entities
|
||||
or previously processed ungrounded entities. First checks for containment relationships,
|
||||
then falls back to fuzzy matching if no containment is found.
|
||||
|
||||
Args:
|
||||
ungrounded_ge_entities: Dictionary mapping entity types to lists of ungrounded entity names
|
||||
grounded_ge_entities: Dictionary mapping entity types to lists of grounded entity names
|
||||
fuzzy_match_threshold: Threshold for fuzzy matching (0-100)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping entity types to dictionaries of {original_entity: matched_entity}
|
||||
"""
|
||||
entity_match_mapping: Dict[str, Dict[str, str]] = defaultdict(dict)
|
||||
processed_entities: Dict[str, Set[str]] = defaultdict(set)
|
||||
|
||||
# For each entity type
|
||||
for entity_type, ungrounded_entities_list in ungrounded_ge_entities.items():
|
||||
grounded_list = grounded_ge_entities.get(entity_type, [])
|
||||
|
||||
# Process each ungrounded entity
|
||||
for ungrounded_entity in ungrounded_entities_list:
|
||||
if ungrounded_entity == "*":
|
||||
continue
|
||||
best_match = None
|
||||
|
||||
# First check if ungrounded entity is contained in or contains any grounded entities
|
||||
for grounded_entity in grounded_list:
|
||||
if (
|
||||
ungrounded_entity.lower() in grounded_entity.lower()
|
||||
or grounded_entity.lower() in ungrounded_entity.lower()
|
||||
):
|
||||
best_match = grounded_entity
|
||||
break
|
||||
|
||||
# If no containment match with grounded entities, check previously processed ungrounded entities
|
||||
if not best_match:
|
||||
for processed_entity in processed_entities[entity_type]:
|
||||
if (
|
||||
ungrounded_entity.lower() in processed_entity.lower()
|
||||
or processed_entity.lower() in ungrounded_entity.lower()
|
||||
):
|
||||
best_match = processed_entity
|
||||
break
|
||||
|
||||
# If still no match, fall back to fuzzy matching
|
||||
if not best_match:
|
||||
best_score = 0
|
||||
|
||||
# Try fuzzy matching with grounded entities
|
||||
for grounded_entity in grounded_list:
|
||||
score = fuzz.ratio(
|
||||
ungrounded_entity.lower(), grounded_entity.lower()
|
||||
)
|
||||
if score > fuzzy_match_threshold and score > best_score:
|
||||
best_match = grounded_entity
|
||||
best_score = score
|
||||
|
||||
# Try fuzzy matching with previously processed ungrounded entities
|
||||
if not best_match:
|
||||
for processed_entity in processed_entities[entity_type]:
|
||||
score = fuzz.ratio(
|
||||
ungrounded_entity.lower(), processed_entity.lower()
|
||||
)
|
||||
if score > fuzzy_match_threshold and score > best_score:
|
||||
best_match = processed_entity
|
||||
best_score = score
|
||||
|
||||
# Record the mapping
|
||||
if best_match:
|
||||
entity_match_mapping[entity_type][ungrounded_entity] = best_match
|
||||
else:
|
||||
# No match found, this becomes a new unique entity
|
||||
entity_match_mapping[entity_type][ungrounded_entity] = ungrounded_entity
|
||||
processed_entities[entity_type].add(ungrounded_entity)
|
||||
|
||||
# Log the results
|
||||
logger.info("Entity matching results:")
|
||||
for entity_type, mappings in entity_match_mapping.items():
|
||||
logger.info(f"\nEntity type: {entity_type}")
|
||||
for original, matched in mappings.items():
|
||||
if original != matched:
|
||||
logger.info(f" Mapped: {original} -> {matched}")
|
||||
else:
|
||||
logger.info(f" New unique entity: {original}")
|
||||
|
||||
return entity_match_mapping
|
||||
|
||||
|
||||
def _match_determined_ge_entities(
|
||||
determined_ge_entity_map: Dict[str, List[str]],
|
||||
determined_ge_entities_by_type: Dict[str, List[str]],
|
||||
fuzzy_match_threshold: int = 80,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""
|
||||
Create a mapping for determined entities by matching them to grounded entities
|
||||
or previously processed ungrounded entities. First checks for containment relationships,
|
||||
then falls back to fuzzy matching if no containment is found.
|
||||
|
||||
Args:
|
||||
ungrounded_ge_entities: Dictionary mapping entity types to lists of ungrounded entity names
|
||||
grounded_ge_entities: Dictionary mapping entity types to lists of grounded entity names
|
||||
fuzzy_match_threshold: Threshold for fuzzy matching (0-100)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping entity types to dictionaries of {original_entity: matched_entity}
|
||||
"""
|
||||
determined_entity_match_mapping: Dict[str, Dict[str, str]] = defaultdict(dict)
|
||||
|
||||
# For each entity type
|
||||
for entity_type, determined_entities_list in determined_ge_entity_map.items():
|
||||
ungrounded_list = determined_ge_entities_by_type.get(entity_type, [])
|
||||
|
||||
# Process each ungrounded entity
|
||||
for ungrounded_entity in ungrounded_list:
|
||||
if ungrounded_entity == "*":
|
||||
continue
|
||||
best_match = None
|
||||
|
||||
# First check if ungrounded entity is contained in or contains any grounded entities
|
||||
for grounded_entity in determined_entities_list:
|
||||
if (
|
||||
ungrounded_entity.lower() in grounded_entity.lower()
|
||||
or grounded_entity.lower() in ungrounded_entity.lower()
|
||||
):
|
||||
best_match = grounded_entity
|
||||
break
|
||||
|
||||
# If still no match, fall back to fuzzy matching
|
||||
if not best_match:
|
||||
best_score = 0
|
||||
|
||||
# Try fuzzy matching with grounded entities
|
||||
for grounded_entity in determined_entities_list:
|
||||
score = fuzz.ratio(
|
||||
ungrounded_entity.lower(), grounded_entity.lower()
|
||||
)
|
||||
if score > fuzzy_match_threshold and score > best_score:
|
||||
best_match = grounded_entity
|
||||
best_score = score
|
||||
|
||||
# Record the mapping
|
||||
if best_match:
|
||||
determined_entity_match_mapping[entity_type][
|
||||
f"{ungrounded_entity}"
|
||||
] = f"{best_match}"
|
||||
else:
|
||||
# No match found, this becomes a new unique entity
|
||||
determined_entity_match_mapping[entity_type][
|
||||
f"{ungrounded_entity}"
|
||||
] = "Other"
|
||||
|
||||
# Log the results
|
||||
logger.info("Entity matching results:")
|
||||
for entity_type, mappings in determined_entity_match_mapping.items():
|
||||
logger.info(f"\nEntity type: {entity_type}")
|
||||
for original, matched in mappings.items():
|
||||
if original != matched:
|
||||
logger.info(f" Mapped: {original} -> {matched}")
|
||||
else:
|
||||
logger.info(f" New unique entity: {original}")
|
||||
|
||||
return determined_entity_match_mapping
|
||||
|
||||
|
||||
def kg_clustering(
|
||||
tenant_id: str, index_name: str, processing_chunk_batch_size: int = 8
|
||||
) -> None:
|
||||
"""
|
||||
Here we will cluster the extractions based on their cluster frameworks.
|
||||
Initially, this will only focus on grounded entities with pre-determined
|
||||
relationships, so 'clustering' is actually not yet required.
|
||||
However, we may need to reconcile entities coming from different sources.
|
||||
|
||||
The primary purpose of this function is to populate the actual KG tables
|
||||
from the temp_extraction tables.
|
||||
|
||||
This will change with deep extraction, where grounded-sourceless entities
|
||||
can be extracted and then need to be clustered.
|
||||
"""
|
||||
|
||||
logger.info(f"Starting kg clustering for tenant {tenant_id}")
|
||||
|
||||
## Retrieval
|
||||
|
||||
source_documents_w_successful_transfers: Set[str] = set()
|
||||
source_documents_w_failed_transfers: Set[str] = set()
|
||||
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
relationship_types = get_all_relationship_types(
|
||||
db_session, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
|
||||
relationships = get_all_relationships(db_session, kg_stage=KGStage.EXTRACTED)
|
||||
|
||||
grounded_entities = get_entities_by_grounding(
|
||||
db_session, KGStage.EXTRACTED, KGGroundingType.GROUNDED
|
||||
)
|
||||
|
||||
## Clustering
|
||||
|
||||
# TODO: re-implement clustering of ungrounded entities as well as
|
||||
# grounded entities that do not have a source document with deep extraction
|
||||
# enabled!
|
||||
# For now we would just create a trivial entity mapping from the
|
||||
# 'unclustered' entities to the 'clustered' entities. So we can simply
|
||||
# transfer the entity information from the Staging to the Normalized
|
||||
# tables.
|
||||
# This will be reimplemented when deep extraction is enabled.
|
||||
|
||||
## Database operations
|
||||
|
||||
# create the clustered objects - entities
|
||||
|
||||
transferred_entities: list[str] = []
|
||||
for grounded_entity in grounded_entities:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
added_entity = add_entity(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
entity_type=grounded_entity.entity_type_id_name,
|
||||
name=grounded_entity.name,
|
||||
occurrences=grounded_entity.occurrences or 1,
|
||||
document_id=grounded_entity.document_id or None,
|
||||
attributes=grounded_entity.attributes or None,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if added_entity:
|
||||
transferred_entities.append(added_entity.id_name)
|
||||
|
||||
transferred_relationship_types: list[str] = []
|
||||
for relationship_type in relationship_types:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
added_relationship_type_id_name = add_relationship_type(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
source_entity_type=relationship_type.source_entity_type_id_name,
|
||||
relationship_type=relationship_type.type,
|
||||
target_entity_type=relationship_type.target_entity_type_id_name,
|
||||
extraction_count=relationship_type.occurrences or 1,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
transferred_relationship_types.append(added_relationship_type_id_name)
|
||||
|
||||
transferred_relationships: list[str] = []
|
||||
for relationship in relationships:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
added_relationship = add_relationship(
|
||||
db_session,
|
||||
KGStage.NORMALIZED,
|
||||
relationship_id_name=relationship.id_name,
|
||||
source_document_id=relationship.source_document or "",
|
||||
occurrences=relationship.occurrences or 1,
|
||||
)
|
||||
|
||||
if relationship.source_document:
|
||||
source_documents_w_successful_transfers.add(
|
||||
relationship.source_document
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
transferred_relationships.append(added_relationship.id_name)
|
||||
|
||||
except Exception as e:
|
||||
if relationship.source_document:
|
||||
source_documents_w_failed_transfers.add(
|
||||
relationship.source_document
|
||||
)
|
||||
logger.error(
|
||||
f"Error transferring relationship {relationship.id_name}: {e}"
|
||||
)
|
||||
|
||||
# TODO: remove the /relationship types & entities that correspond to relationships
|
||||
# source documents that failed to transfer. I.e, do a proper rollback
|
||||
|
||||
# TODO: update Vespa info when clustering/changes are performed
|
||||
|
||||
# delete the added objects from the staging tables
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_relationships_by_id_names(
|
||||
db_session, transferred_relationships, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationships: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_relationship_types_by_id_names(
|
||||
db_session, transferred_relationship_types, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting relationship types: {e}")
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_entities_by_id_names(
|
||||
db_session, transferred_entities, kg_stage=KGStage.EXTRACTED
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting entities: {e}")
|
||||
|
||||
# Update document kg info
|
||||
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# all_kg_extracted_documents_info = get_all_kg_extracted_documents_info(
|
||||
# db_session
|
||||
# )
|
||||
|
||||
for document_id in source_documents_w_successful_transfers:
|
||||
|
||||
# Update the document kg info
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_info(
|
||||
db_session,
|
||||
document_id=document_id,
|
||||
kg_stage=KGStage.NORMALIZED,
|
||||
)
|
||||
db_session.commit()
|
||||
311
backend/onyx/kg/clustering/normalizations.py
Normal file
311
backend/onyx/kg/clustering/normalizations.py
Normal file
@@ -0,0 +1,311 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from thefuzz import fuzz # type: ignore
|
||||
from thefuzz import process # type: ignore
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_entity_names_for_types
|
||||
from onyx.db.relationships import get_relationships_for_entity_type_pairs
|
||||
from onyx.kg.models import NormalizedEntities
|
||||
from onyx.kg.models import NormalizedRelationships
|
||||
from onyx.kg.models import NormalizedTerms
|
||||
from onyx.kg.utils.embeddings import encode_string_batch
|
||||
|
||||
|
||||
def _split_entity_type_v_name(entity: str) -> tuple[str, str]:
|
||||
"""
|
||||
Split an entity string into type and name.
|
||||
"""
|
||||
|
||||
entity_split = entity.split(":")
|
||||
if len(entity_split) < 2:
|
||||
raise ValueError(f"Invalid entity: {entity}")
|
||||
|
||||
entity_type = entity_split[0]
|
||||
entity_name = ":".join(entity_split[1:])
|
||||
|
||||
return entity_type, entity_name
|
||||
|
||||
|
||||
def _get_existing_normalized_entities(
|
||||
raw_entities: List[str],
|
||||
) -> List[tuple[str, str | None]]:
|
||||
"""
|
||||
Get existing normalized entities from the database.
|
||||
"""
|
||||
|
||||
entity_types = list(set([entity.split(":")[0] for entity in raw_entities]))
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
entities = get_entity_names_for_types(db_session, entity_types)
|
||||
|
||||
return entities
|
||||
|
||||
|
||||
def _get_existing_normalized_relationships(
|
||||
raw_relationships: List[str],
|
||||
) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Get existing normalized relationships from the database.
|
||||
"""
|
||||
|
||||
relationship_type_map: Dict[str, Dict[str, List[str]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
relationship_pairs = list(
|
||||
set(
|
||||
[
|
||||
(
|
||||
relationship.split("__")[0].split(":")[0],
|
||||
relationship.split("__")[2].split(":")[0],
|
||||
)
|
||||
for relationship in raw_relationships
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
relationships = get_relationships_for_entity_type_pairs(
|
||||
db_session, relationship_pairs
|
||||
)
|
||||
|
||||
for relationship in relationships:
|
||||
relationship_type_map[relationship.source_entity_type_id_name][
|
||||
relationship.target_entity_type_id_name
|
||||
].append(relationship.id_name)
|
||||
|
||||
return relationship_type_map
|
||||
|
||||
|
||||
def normalize_entities(
|
||||
raw_entities_no_attributes: List[str],
|
||||
) -> NormalizedEntities:
|
||||
"""
|
||||
Match each entity against a list of normalized entities using fuzzy matching.
|
||||
Returns the best matching normalized entity for each input entity.
|
||||
|
||||
Args:
|
||||
raw_entities_no_attributes: List of entity strings to normalize, w/o attributes
|
||||
|
||||
Returns:
|
||||
List of normalized entity strings
|
||||
"""
|
||||
# TODO: this probably should move to a new Vespa schema for entity normalization
|
||||
# TODO: as is, this should be converted to a generator going through the entities
|
||||
# in large batches, to avoid memory issues
|
||||
|
||||
# Assume this is your predefined list of normalized entities
|
||||
norm_entities = _get_existing_normalized_entities(raw_entities_no_attributes)
|
||||
|
||||
norm_entity_semantic_to_id_map: dict[str, dict[str, str]] = defaultdict(dict)
|
||||
|
||||
for norm_entity_tuple in norm_entities:
|
||||
if norm_entity_tuple[1] is None:
|
||||
continue
|
||||
entity_type, norm_entity_semantic_name = _split_entity_type_v_name(
|
||||
norm_entity_tuple[1]
|
||||
)
|
||||
norm_entity_semantic_to_id_map[entity_type][norm_entity_semantic_name] = (
|
||||
_split_entity_type_v_name(norm_entity_tuple[0])[1]
|
||||
)
|
||||
|
||||
normalized_results: List[str] = []
|
||||
normalized_map: Dict[str, str | None] = {}
|
||||
threshold = 80 # Adjust threshold as needed
|
||||
|
||||
base_norm_entities = [norm_entity[0] for norm_entity in norm_entities]
|
||||
|
||||
for entity in raw_entities_no_attributes:
|
||||
entity_type, entity_name = entity.split(":")
|
||||
if entity_name == "*":
|
||||
normalized_results.append(entity)
|
||||
normalized_map[entity] = entity
|
||||
continue
|
||||
|
||||
# Find the best match and its score from norm_entities
|
||||
all_entity_match_possibilities = norm_entity_semantic_to_id_map[
|
||||
entity_type
|
||||
].keys()
|
||||
best_match, score = process.extractOne(
|
||||
entity_name, all_entity_match_possibilities, scorer=fuzz.partial_ratio
|
||||
)
|
||||
|
||||
if best_match not in base_norm_entities:
|
||||
best_match_id = norm_entity_semantic_to_id_map[entity_type][
|
||||
best_match
|
||||
] # replace semantic_id with the actual id
|
||||
else:
|
||||
best_match_id = best_match
|
||||
|
||||
if score >= threshold:
|
||||
normalized_results.append(f"{entity_type}:{best_match_id}")
|
||||
normalized_map[entity] = f"{entity_type}:{best_match_id}"
|
||||
else:
|
||||
# If no good match found, keep original
|
||||
normalized_map[entity] = entity
|
||||
|
||||
return NormalizedEntities(
|
||||
entities=normalized_results, entity_normalization_map=normalized_map
|
||||
)
|
||||
|
||||
|
||||
def normalize_entities_w_attributes_from_map(
|
||||
raw_entities_w_attributes: List[str],
|
||||
entity_normalization_map: Dict[str, Optional[str]],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Normalize entities with attributes using the entity normalization map.
|
||||
"""
|
||||
|
||||
normalized_entities_w_attributes: List[str] = []
|
||||
|
||||
for raw_entities_w_attribute in raw_entities_w_attributes:
|
||||
assert (
|
||||
len(raw_entities_w_attribute.split("--")) == 2
|
||||
), f"Invalid entity with attributes: {raw_entities_w_attribute}"
|
||||
raw_entity, attributes = raw_entities_w_attribute.split("--")
|
||||
normalized_entity = entity_normalization_map.get(raw_entity.strip())
|
||||
if normalized_entity is None:
|
||||
continue
|
||||
else:
|
||||
normalized_entities_w_attributes.append(
|
||||
f"{normalized_entity}--{raw_entities_w_attribute.split('--')[1].strip()}"
|
||||
)
|
||||
|
||||
return normalized_entities_w_attributes
|
||||
|
||||
|
||||
def normalize_relationships(
|
||||
raw_relationships: List[str], entity_normalization_map: Dict[str, Optional[str]]
|
||||
) -> NormalizedRelationships:
|
||||
"""
|
||||
Normalize relationships using entity mappings and relationship string matching.
|
||||
|
||||
Args:
|
||||
relationships: List of relationships in format "source__relation__target"
|
||||
entity_normalization_map: Mapping of raw entities to normalized ones (or None)
|
||||
|
||||
Returns:
|
||||
NormalizedRelationships containing normalized relationships and mapping
|
||||
"""
|
||||
# Placeholder for normalized relationship structure
|
||||
nor_relationships = _get_existing_normalized_relationships(raw_relationships)
|
||||
|
||||
normalized_rels: List[str] = []
|
||||
normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
for raw_rel in raw_relationships:
|
||||
# 1. Split and normalize entities
|
||||
try:
|
||||
source, rel_string, target = raw_rel.split("__")
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid relationship format: {raw_rel}")
|
||||
|
||||
# Check if entities are in normalization map and not None
|
||||
norm_source = entity_normalization_map.get(source)
|
||||
norm_target = entity_normalization_map.get(target)
|
||||
|
||||
if norm_source is None or norm_target is None:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 2. Find candidate normalized relationships
|
||||
candidate_rels = []
|
||||
norm_source_type = norm_source.split(":")[0]
|
||||
norm_target_type = norm_target.split(":")[0]
|
||||
if (
|
||||
norm_source_type in nor_relationships
|
||||
and norm_target_type in nor_relationships[norm_source_type]
|
||||
):
|
||||
candidate_rels = [
|
||||
rel.split("__")[1]
|
||||
for rel in nor_relationships[norm_source_type][norm_target_type]
|
||||
]
|
||||
|
||||
if not candidate_rels:
|
||||
normalization_map[raw_rel] = None
|
||||
continue
|
||||
|
||||
# 3. Encode and find best match
|
||||
strings_to_encode = [rel_string] + candidate_rels
|
||||
vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# Get raw relation vector and candidate vectors
|
||||
raw_vector = vectors[0]
|
||||
candidate_vectors = vectors[1:]
|
||||
|
||||
# Calculate dot products
|
||||
dot_products = np.dot(candidate_vectors, raw_vector)
|
||||
best_match_idx = np.argmax(dot_products)
|
||||
|
||||
# Create normalized relationship
|
||||
norm_rel = f"{norm_source}__{candidate_rels[best_match_idx]}__{norm_target}"
|
||||
normalized_rels.append(norm_rel)
|
||||
normalization_map[raw_rel] = norm_rel
|
||||
|
||||
return NormalizedRelationships(
|
||||
relationships=normalized_rels, relationship_normalization_map=normalization_map
|
||||
)
|
||||
|
||||
|
||||
def normalize_terms(raw_terms: List[str]) -> NormalizedTerms:
|
||||
"""
|
||||
Normalize terms using semantic similarity matching.
|
||||
|
||||
Args:
|
||||
terms: List of terms to normalize
|
||||
|
||||
Returns:
|
||||
NormalizedTerms containing normalized terms and mapping
|
||||
"""
|
||||
# # Placeholder for normalized terms - this would typically come from a predefined list
|
||||
# normalized_term_list = [
|
||||
# "algorithm",
|
||||
# "database",
|
||||
# "software",
|
||||
# "programming",
|
||||
# # ... other normalized terms ...
|
||||
# ]
|
||||
|
||||
# normalized_terms: List[str] = []
|
||||
# normalization_map: Dict[str, str | None] = {}
|
||||
|
||||
# if not raw_terms:
|
||||
# return NormalizedTerms(terms=[], term_normalization_map={})
|
||||
|
||||
# # Encode all terms at once for efficiency
|
||||
# strings_to_encode = raw_terms + normalized_term_list
|
||||
# vectors = encode_string_batch(strings_to_encode)
|
||||
|
||||
# # Split vectors into query terms and candidate terms
|
||||
# query_vectors = vectors[:len(raw_terms)]
|
||||
# candidate_vectors = vectors[len(raw_terms):]
|
||||
|
||||
# # Calculate similarity for each term
|
||||
# for i, term in enumerate(raw_terms):
|
||||
# # Calculate dot products with all candidates
|
||||
# similarities = np.dot(candidate_vectors, query_vectors[i])
|
||||
# best_match_idx = np.argmax(similarities)
|
||||
# best_match_score = similarities[best_match_idx]
|
||||
|
||||
# # Use a threshold to determine if the match is good enough
|
||||
# if best_match_score > 0.7: # Adjust threshold as needed
|
||||
# normalized_term = normalized_term_list[best_match_idx]
|
||||
# normalized_terms.append(normalized_term)
|
||||
# normalization_map[term] = normalized_term
|
||||
# else:
|
||||
# # If no good match found, keep original
|
||||
# normalization_map[term] = None
|
||||
|
||||
# return NormalizedTerms(
|
||||
# terms=normalized_terms,
|
||||
# term_normalization_map=normalization_map
|
||||
# )
|
||||
|
||||
return NormalizedTerms(
|
||||
terms=raw_terms, term_normalization_map={term: term for term in raw_terms}
|
||||
)
|
||||
53
backend/onyx/kg/configuration.py
Normal file
53
backend/onyx/kg/configuration.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entity_type import populate_default_employee_account_information
|
||||
from onyx.db.entity_type import (
|
||||
populate_default_primary_grounded_entity_type_information,
|
||||
)
|
||||
from onyx.db.kg_config import get_kg_enablement
|
||||
from onyx.db.kg_config import KGConfigSettings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def populate_default_grounded_entity_types() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if not get_kg_enablement(db_session):
|
||||
logger.error(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
raise ValueError(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
|
||||
populate_default_primary_grounded_entity_type_information(db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def populate_default_account_employee_definitions() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if not get_kg_enablement(db_session):
|
||||
logger.error(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
raise ValueError(
|
||||
"KG approach is not enabled, the entity types cannot be populated."
|
||||
)
|
||||
|
||||
populate_default_employee_account_information(db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def execute_kg_setting_tests(kg_config_settings: KGConfigSettings) -> None:
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
if not kg_config_settings.KG_VENDOR:
|
||||
raise ValueError("KG_VENDOR is not set")
|
||||
if not kg_config_settings.KG_VENDOR_DOMAINS:
|
||||
raise ValueError("KG_VENDOR_DOMAINS is not set")
|
||||
1521
backend/onyx/kg/extractions/extraction_processing.py
Normal file
1521
backend/onyx/kg/extractions/extraction_processing.py
Normal file
File diff suppressed because it is too large
Load Diff
84
backend/onyx/kg/kg_default_entity_definitions.py
Normal file
84
backend/onyx/kg/kg_default_entity_definitions.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.kg.models import KGDefaultEntityDefinition
|
||||
from onyx.kg.models import KGGroundingType
|
||||
|
||||
|
||||
class KGDefaultPrimaryGroundedEntityDefinitions(BaseModel):
|
||||
|
||||
LINEAR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A formal ticket about a product issue or improvement request.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="linear",
|
||||
)
|
||||
|
||||
FIREFLIES: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A phone call transcript between us (---vendor_name---) \
|
||||
and another account or individuals, or an internal meeting.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="fireflies",
|
||||
)
|
||||
|
||||
GONG: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A phone call transcript between us (---vendor_name---) \
|
||||
and another account or individuals, or an internal meeting.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="gong",
|
||||
)
|
||||
|
||||
SLACK: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A Slack conversation.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="slack",
|
||||
)
|
||||
|
||||
WEB: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A web page.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="web",
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A Google Drive document.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="google_drive",
|
||||
)
|
||||
|
||||
GMAIL: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="An email.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="gmail",
|
||||
)
|
||||
|
||||
JIRA: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A formal JIRA ticket about a product issue or improvement request.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
grounded_source_name="jira",
|
||||
)
|
||||
|
||||
|
||||
class KGDefaultAccountEmployeeDefinitions(BaseModel):
|
||||
|
||||
VENDOR: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="The Vendor ---vendor_name---, 'us'",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
|
||||
ACCOUNT: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A company that was, is, or potentially could be a customer of the vendor \
|
||||
('us, ---vendor_name---'). Note that ---vendor_name--- can never be an ACCOUNT.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
|
||||
EMPLOYEE: KGDefaultEntityDefinition = KGDefaultEntityDefinition(
|
||||
description="A person who speaks on \
|
||||
behalf of 'our' company (the VENDOR ---vendor_name---), NOT of another account. Therefore, employees of other companies \
|
||||
are NOT included here. If in doubt, do NOT extract.",
|
||||
grounding=KGGroundingType.GROUNDED,
|
||||
active=False,
|
||||
grounded_source_name=None,
|
||||
)
|
||||
219
backend/onyx/kg/models.py
Normal file
219
backend/onyx/kg/models.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KGConfigSettings(BaseModel):
|
||||
KG_ENABLED: bool = False
|
||||
KG_VENDOR: str | None = None
|
||||
KG_VENDOR_DOMAINS: list[str] | None = None
|
||||
KG_IGNORE_EMAIL_DOMAINS: list[str] | None = None
|
||||
|
||||
|
||||
class KGConfigVars(str, Enum):
|
||||
KG_ENABLED = "KG_ENABLED"
|
||||
KG_VENDOR = "KG_VENDOR"
|
||||
KG_VENDOR_DOMAINS = "KG_VENDOR_DOMAINS"
|
||||
KG_IGNORE_EMAIL_DOMAINS = "KG_IGNORE_EMAIL_DOMAINS"
|
||||
|
||||
|
||||
class KGChunkFormat(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
title: str
|
||||
content: str
|
||||
primary_owners: list[str]
|
||||
secondary_owners: list[str]
|
||||
source_type: str
|
||||
metadata: dict[str, str | list[str]] | None = None
|
||||
entities: Dict[str, int] = {}
|
||||
relationships: Dict[str, int] = {}
|
||||
terms: Dict[str, int] = {}
|
||||
deep_extraction: bool = False
|
||||
|
||||
|
||||
class KGChunkExtraction(BaseModel):
|
||||
connector_id: int
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
core_entity: str
|
||||
entities: list[str]
|
||||
relationships: list[str]
|
||||
terms: list[str]
|
||||
attributes: dict[str, str | list[str]]
|
||||
|
||||
|
||||
class KGChunkId(BaseModel):
|
||||
connector_id: int | None = None
|
||||
document_id: str
|
||||
chunk_id: int
|
||||
|
||||
|
||||
class KGRelationshipExtraction(BaseModel):
|
||||
relationship_str: str
|
||||
source_document_id: str
|
||||
|
||||
|
||||
class KGAggregatedExtractions(BaseModel):
|
||||
grounded_entities_document_ids: dict[str, str]
|
||||
entities: dict[str, int]
|
||||
relationships: dict[str, dict[str, int]]
|
||||
terms: dict[str, int]
|
||||
attributes: dict[str, dict[str, str | list[str]]]
|
||||
|
||||
|
||||
class KGBatchExtractionStats(BaseModel):
|
||||
connector_id: int | None = None
|
||||
succeeded: list[KGChunkId]
|
||||
failed: list[KGChunkId]
|
||||
aggregated_kg_extractions: KGAggregatedExtractions
|
||||
|
||||
|
||||
class ConnectorExtractionStats(BaseModel):
|
||||
connector_id: int
|
||||
num_succeeded: int
|
||||
num_failed: int
|
||||
num_processed: int
|
||||
|
||||
|
||||
class KGPerson(BaseModel):
|
||||
name: str
|
||||
company: str
|
||||
employee: bool
|
||||
|
||||
|
||||
class NormalizedEntities(BaseModel):
|
||||
entities: list[str]
|
||||
entity_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedRelationships(BaseModel):
|
||||
relationships: list[str]
|
||||
relationship_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class NormalizedTerms(BaseModel):
|
||||
terms: list[str]
|
||||
term_normalization_map: dict[str, str | None]
|
||||
|
||||
|
||||
class KGClassificationContent(BaseModel):
|
||||
document_id: str
|
||||
classification_content: str
|
||||
source_type: str
|
||||
source_metadata: dict[str, Any] | None = None
|
||||
entity_type: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGEnrichedClassificationContent(KGClassificationContent):
|
||||
classification_enabled: bool
|
||||
classification_instructions: dict[str, Any]
|
||||
deep_extraction: bool
|
||||
|
||||
|
||||
class KGClassificationDecisions(BaseModel):
|
||||
document_id: str
|
||||
classification_decision: bool
|
||||
classification_class: str | None
|
||||
source_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGClassificationRule(BaseModel):
|
||||
description: str
|
||||
extration: bool
|
||||
|
||||
|
||||
class KGClassificationInstructions(BaseModel):
|
||||
classification_enabled: bool
|
||||
classification_options: str
|
||||
classification_class_definitions: dict[str, Dict[str, str | bool]]
|
||||
|
||||
|
||||
class KGExtractionInstructions(BaseModel):
|
||||
deep_extraction: bool
|
||||
active: bool
|
||||
|
||||
|
||||
class KGEntityTypeInstructions(BaseModel):
|
||||
classification_instructions: KGClassificationInstructions
|
||||
extraction_instructions: KGExtractionInstructions
|
||||
filter_instructions: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class KGEnhancedDocumentMetadata(BaseModel):
|
||||
entity_type: str | None
|
||||
document_attributes: dict[str, Any] | None
|
||||
deep_extraction: bool
|
||||
classification_enabled: bool
|
||||
classification_instructions: KGClassificationInstructions | None
|
||||
skip: bool
|
||||
|
||||
|
||||
class ContextPreparation(BaseModel):
|
||||
"""
|
||||
Context preparation format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_context: str
|
||||
core_entity: str
|
||||
implied_entities: list[str]
|
||||
implied_relationships: list[str]
|
||||
implied_terms: list[str]
|
||||
|
||||
|
||||
class KGDocumentClassificationPrompt(BaseModel):
|
||||
"""
|
||||
Document classification prompt format for the LLM KG extraction.
|
||||
"""
|
||||
|
||||
llm_prompt: str | None
|
||||
|
||||
|
||||
class KGConnectorData(BaseModel):
|
||||
id: int
|
||||
source: str
|
||||
|
||||
|
||||
class KGStage(str, Enum):
|
||||
EXTRACTED = "extracted"
|
||||
NORMALIZED = "normalized"
|
||||
FAILED = "failed"
|
||||
SKIPPED = "skipped"
|
||||
NOT_STARTED = "not_started"
|
||||
EXTRACTING = "extracting"
|
||||
DO_NOT_EXTRACT = "do_not_extract"
|
||||
|
||||
|
||||
class KGDocumentEntitiesRelationshipsAttributes(BaseModel):
|
||||
kg_core_document_id_name: str
|
||||
implied_entities: set[str]
|
||||
implied_relationships: set[str]
|
||||
converted_relationships_to_attributes: dict[str, list[str]]
|
||||
company_participant_emails: set[str]
|
||||
account_participant_emails: set[str]
|
||||
converted_attributes_to_relationships: set[str]
|
||||
document_attributes: dict[str, Any] | None
|
||||
|
||||
|
||||
class KGGroundingType(str, Enum):
|
||||
UNGROUNDED = "ungrounded"
|
||||
GROUNDED = "grounded"
|
||||
|
||||
|
||||
class KGDefaultEntityDefinition(BaseModel):
|
||||
description: str
|
||||
grounding: KGGroundingType
|
||||
active: bool = False
|
||||
grounded_source_name: str | None
|
||||
attributes: dict = {}
|
||||
entity_values: dict = {}
|
||||
|
||||
|
||||
class KGEntityInformation(BaseModel):
|
||||
entity_type: str
|
||||
entity_name: str
|
||||
occurences: int
|
||||
1
backend/onyx/kg/resets/reset_connector.py
Normal file
1
backend/onyx/kg/resets/reset_connector.py
Normal file
@@ -0,0 +1 @@
|
||||
# TODO: Implement this
|
||||
21
backend/onyx/kg/resets/reset_extractions.py
Normal file
21
backend/onyx/kg/resets/reset_extractions.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from onyx.db.document import update_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
|
||||
def reset_extraction_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationshipExtractionStaging).delete()
|
||||
db_session.query(KGEntityExtractionStaging).delete()
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_stages(db_session, KGStage.EXTRACTED, KGStage.NOT_STARTED)
|
||||
db_session.commit()
|
||||
26
backend/onyx/kg/resets/reset_index.py
Normal file
26
backend/onyx/kg/resets/reset_index.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from onyx.db.document import reset_all_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.db.models import KGRelationshipTypeExtractionStaging
|
||||
|
||||
|
||||
def reset_full_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationship).delete()
|
||||
db_session.query(KGRelationshipType).delete()
|
||||
db_session.query(KGEntity).delete()
|
||||
db_session.query(KGRelationshipExtractionStaging).delete()
|
||||
db_session.query(KGEntityExtractionStaging).delete()
|
||||
db_session.query(KGRelationshipTypeExtractionStaging).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
reset_all_document_kg_stages(db_session)
|
||||
db_session.commit()
|
||||
22
backend/onyx/kg/resets/reset_normalizations.py
Normal file
22
backend/onyx/kg/resets/reset_normalizations.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from onyx.db.document import update_document_kg_stages
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGRelationship
|
||||
from onyx.db.models import KGRelationshipType
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
|
||||
def reset_normalization_kg_index() -> None:
|
||||
"""
|
||||
Resets the knowledge graph index.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KGRelationship).delete()
|
||||
db_session.query(KGEntity).delete()
|
||||
db_session.query(KGRelationshipType).delete()
|
||||
db_session.commit()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
update_document_kg_stages(db_session, KGStage.NORMALIZED, KGStage.EXTRACTED)
|
||||
db_session.commit()
|
||||
23
backend/onyx/kg/utils/embeddings.py
Normal file
23
backend/onyx/kg/utils/embeddings.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbedTextType
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
def encode_string_batch(strings: List[str]) -> np.ndarray:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=current_search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
# Get embeddings while session is still open
|
||||
embedding = model.encode(strings, text_type=EmbedTextType.QUERY)
|
||||
return np.array(embedding)
|
||||
410
backend/onyx/kg/utils/extraction_utils.py
Normal file
410
backend/onyx/kg/utils/extraction_utils.py
Normal file
@@ -0,0 +1,410 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
|
||||
from onyx.configs.constants import OnyxCallTypes
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import get_kg_entity_by_document
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import KGConfigSettings
|
||||
from onyx.db.models import Document
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.models import (
|
||||
KGDocumentClassificationPrompt,
|
||||
)
|
||||
from onyx.kg.models import KGDocumentEntitiesRelationshipsAttributes
|
||||
from onyx.kg.models import KGEnhancedDocumentMetadata
|
||||
from onyx.kg.utils.formatting_utils import generalize_entities
|
||||
from onyx.kg.utils.formatting_utils import kg_email_processing
|
||||
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
|
||||
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
|
||||
from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT
|
||||
|
||||
|
||||
def _update_implied_entities_relationships(
|
||||
kg_core_document_id_name: str,
|
||||
owner_list: list[str],
|
||||
implied_entities: set[str],
|
||||
implied_relationships: set[str],
|
||||
company_participant_emails: set[str],
|
||||
account_participant_emails: set[str],
|
||||
relationship_type: str,
|
||||
kg_config_settings: KGConfigSettings,
|
||||
converted_relationships_to_attributes: dict[str, list[str]],
|
||||
) -> tuple[set[str], set[str], set[str], set[str], dict[str, list[str]]]:
|
||||
|
||||
for owner in owner_list or []:
|
||||
if is_email(owner):
|
||||
(
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
) = kg_process_person(
|
||||
owner,
|
||||
kg_core_document_id_name,
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
relationship_type,
|
||||
kg_config_settings,
|
||||
)
|
||||
else:
|
||||
converted_relationships_to_attributes[relationship_type].append(owner)
|
||||
|
||||
return (
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
converted_relationships_to_attributes,
|
||||
)
|
||||
|
||||
|
||||
def kg_document_entities_relationships_attribute_generation(
|
||||
document: Document,
|
||||
doc_metadata: KGEnhancedDocumentMetadata,
|
||||
active_entities: list[str],
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> KGDocumentEntitiesRelationshipsAttributes:
|
||||
"""
|
||||
Generate entities, relationships, and attributes for a document.
|
||||
"""
|
||||
|
||||
# Get document entity type from the KGEnhancedDocumentMetadata
|
||||
document_entity_type = doc_metadata.entity_type
|
||||
assert document_entity_type is not None
|
||||
|
||||
# Get additional document attributes from the KGEnhancedDocumentMetadata
|
||||
document_attributes = doc_metadata.document_attributes
|
||||
|
||||
implied_entities: set[str] = set()
|
||||
implied_relationships: set[str] = (
|
||||
set()
|
||||
) # 'Relationships' that will be captured as KG relationships
|
||||
converted_relationships_to_attributes: dict[str, list[str]] = defaultdict(
|
||||
list
|
||||
) # 'Relationships' that will be captured as KG entity attributes
|
||||
|
||||
converted_attributes_to_relationships: set[str] = (
|
||||
set()
|
||||
) # Attributes that should be captures as entities and then relationships (Account = ...)
|
||||
|
||||
company_participant_emails: set[str] = (
|
||||
set()
|
||||
) # Quantity needed for call processing - participants from vendor
|
||||
account_participant_emails: set[str] = (
|
||||
set()
|
||||
) # Quantity needed for call processing - external participants
|
||||
|
||||
# Chunk treatment variables
|
||||
|
||||
document_is_from_call = document_entity_type.lower() in [
|
||||
call_type.value.lower() for call_type in OnyxCallTypes
|
||||
]
|
||||
|
||||
# Get core entity
|
||||
|
||||
document_id = document.id
|
||||
primary_owners = document.primary_owners
|
||||
secondary_owners = document.secondary_owners
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
kg_core_document = get_kg_entity_by_document(db_session, document_id)
|
||||
|
||||
if kg_core_document:
|
||||
kg_core_document_id_name = kg_core_document.id_name
|
||||
else:
|
||||
kg_core_document_id_name = f"{document_entity_type.upper()}:{document_id}"
|
||||
|
||||
# Get implied entities and relationships from primary/secondary owners
|
||||
|
||||
if document_is_from_call:
|
||||
(
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
converted_relationships_to_attributes,
|
||||
) = _update_implied_entities_relationships(
|
||||
kg_core_document_id_name,
|
||||
owner_list=(primary_owners or []) + (secondary_owners or []),
|
||||
implied_entities=implied_entities,
|
||||
implied_relationships=implied_relationships,
|
||||
company_participant_emails=company_participant_emails,
|
||||
account_participant_emails=account_participant_emails,
|
||||
relationship_type="participates_in",
|
||||
kg_config_settings=kg_config_settings,
|
||||
converted_relationships_to_attributes=converted_relationships_to_attributes,
|
||||
)
|
||||
else:
|
||||
(
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
converted_relationships_to_attributes,
|
||||
) = _update_implied_entities_relationships(
|
||||
kg_core_document_id_name,
|
||||
owner_list=primary_owners or [],
|
||||
implied_entities=implied_entities,
|
||||
implied_relationships=implied_relationships,
|
||||
company_participant_emails=company_participant_emails,
|
||||
account_participant_emails=account_participant_emails,
|
||||
relationship_type="leads",
|
||||
kg_config_settings=kg_config_settings,
|
||||
converted_relationships_to_attributes=converted_relationships_to_attributes,
|
||||
)
|
||||
|
||||
(
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
converted_relationships_to_attributes,
|
||||
) = _update_implied_entities_relationships(
|
||||
kg_core_document_id_name,
|
||||
owner_list=secondary_owners or [],
|
||||
implied_entities=implied_entities,
|
||||
implied_relationships=implied_relationships,
|
||||
company_participant_emails=company_participant_emails,
|
||||
account_participant_emails=account_participant_emails,
|
||||
relationship_type="participates_in",
|
||||
kg_config_settings=kg_config_settings,
|
||||
converted_relationships_to_attributes=converted_relationships_to_attributes,
|
||||
)
|
||||
|
||||
if document_attributes is not None:
|
||||
cleaned_document_attributes = document_attributes.copy()
|
||||
for attribute, value in document_attributes.items():
|
||||
if attribute.lower() in [x.lower() for x in active_entities]:
|
||||
converted_attributes_to_relationships.add(attribute)
|
||||
if isinstance(value, str):
|
||||
implied_entity = f"{attribute.upper()}:{value.capitalize()}"
|
||||
implied_entities.add(implied_entity)
|
||||
implied_relationships.add(
|
||||
f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
|
||||
)
|
||||
|
||||
implied_entity = f"{attribute.upper()}:*"
|
||||
implied_entities.add(implied_entity)
|
||||
implied_relationships.add(
|
||||
f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
|
||||
)
|
||||
|
||||
implied_entity = f"{attribute.upper()}:*"
|
||||
implied_entities.add(implied_entity)
|
||||
implied_relationships.add(
|
||||
f"{implied_entity}__is_{attribute.lower()}_of__{document_entity_type.upper()}:*"
|
||||
)
|
||||
|
||||
implied_entity = f"{attribute.upper()}:{value.capitalize()}"
|
||||
implied_entities.add(implied_entity)
|
||||
implied_relationships.add(
|
||||
f"{implied_entity}__is_{attribute.lower()}_of__{document_entity_type.upper()}:*"
|
||||
)
|
||||
|
||||
cleaned_document_attributes.pop(attribute)
|
||||
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
implied_entity = f"{attribute.upper()}:{item.capitalize()}"
|
||||
implied_entities.add(implied_entity)
|
||||
implied_relationships.add(
|
||||
f"{implied_entity}__is_{attribute.lower()}_of__{kg_core_document_id_name}"
|
||||
)
|
||||
cleaned_document_attributes.pop(attribute)
|
||||
if attribute.lower().endswith("_id") or attribute.endswith("Id"):
|
||||
cleaned_document_attributes.pop(attribute)
|
||||
else:
|
||||
cleaned_document_attributes = None
|
||||
|
||||
return KGDocumentEntitiesRelationshipsAttributes(
|
||||
kg_core_document_id_name=kg_core_document_id_name,
|
||||
implied_entities=implied_entities,
|
||||
implied_relationships=implied_relationships,
|
||||
company_participant_emails=company_participant_emails,
|
||||
account_participant_emails=account_participant_emails,
|
||||
converted_relationships_to_attributes=converted_relationships_to_attributes,
|
||||
converted_attributes_to_relationships=converted_attributes_to_relationships,
|
||||
document_attributes=cleaned_document_attributes,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_llm_document_content_call(
|
||||
document_classification_content: KGClassificationContent,
|
||||
category_list: str,
|
||||
category_definition_string: str,
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> KGDocumentClassificationPrompt:
|
||||
"""
|
||||
Calls - prepare prompt for the LLM classification.
|
||||
"""
|
||||
|
||||
prompt = CALL_DOCUMENT_CLASSIFICATION_PROMPT.format(
|
||||
beginning_of_call_content=document_classification_content.classification_content,
|
||||
category_list=category_list,
|
||||
category_options=category_definition_string,
|
||||
vendor=kg_config_settings.KG_VENDOR,
|
||||
)
|
||||
|
||||
return KGDocumentClassificationPrompt(
|
||||
llm_prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
def kg_process_person(
|
||||
person: str,
|
||||
core_document_id_name: str,
|
||||
implied_entities: set[str],
|
||||
implied_relationships: set[str],
|
||||
company_participant_emails: set[str],
|
||||
account_participant_emails: set[str],
|
||||
relationship_type: str,
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> tuple[set[str], set[str], set[str], set[str]]:
|
||||
"""
|
||||
Process a single owner and return updated sets with entities and relationships.
|
||||
|
||||
Returns:
|
||||
tuple containing (implied_entities, implied_relationships, company_participant_emails, account_participant_emails)
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
kg_config_settings = get_kg_config_settings(db_session)
|
||||
if not kg_config_settings.KG_ENABLED:
|
||||
raise ValueError("KG is not enabled")
|
||||
|
||||
assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
kg_person = kg_email_processing(person, kg_config_settings)
|
||||
if any(
|
||||
domain.lower() in kg_person.company.lower()
|
||||
for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
return (
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
)
|
||||
|
||||
if kg_person.employee:
|
||||
company_participant_emails = company_participant_emails | {
|
||||
f"{kg_person.name} -- ({kg_person.company})"
|
||||
}
|
||||
if kg_person.name not in implied_entities:
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_entities = implied_entities | {f"EMPLOYEE:{kg_person.name}"}
|
||||
implied_relationships = implied_relationships | {
|
||||
f"EMPLOYEE:{kg_person.name}__{relationship_type}__{core_document_id_name}",
|
||||
f"EMPLOYEE:{kg_person.name}__{relationship_type}__{generalized_target_entity}",
|
||||
f"EMPLOYEE:*__{relationship_type}__{core_document_id_name}",
|
||||
f"EMPLOYEE:*__{relationship_type}__{generalized_target_entity}",
|
||||
}
|
||||
if kg_person.company not in implied_entities:
|
||||
implied_entities = implied_entities | {f"VENDOR:{kg_person.company}"}
|
||||
implied_relationships = implied_relationships | {
|
||||
f"VENDOR:{kg_person.company}__{relationship_type}__{core_document_id_name}",
|
||||
f"VENDOR:{kg_person.company}__{relationship_type}__{generalized_target_entity}",
|
||||
}
|
||||
|
||||
else:
|
||||
account_participant_emails = account_participant_emails | {
|
||||
f"{kg_person.name} -- ({kg_person.company})"
|
||||
}
|
||||
if kg_person.company not in implied_entities:
|
||||
implied_entities = implied_entities | {
|
||||
f"ACCOUNT:{kg_person.company}",
|
||||
"ACCOUNT:*",
|
||||
}
|
||||
implied_relationships = implied_relationships | {
|
||||
f"ACCOUNT:{kg_person.company}__{relationship_type}__{core_document_id_name}",
|
||||
f"ACCOUNT:*__{relationship_type}__{core_document_id_name}",
|
||||
}
|
||||
|
||||
generalized_target_entity = list(
|
||||
generalize_entities([core_document_id_name])
|
||||
)[0]
|
||||
|
||||
implied_relationships = implied_relationships | {
|
||||
f"ACCOUNT:*__{relationship_type}__{generalized_target_entity}",
|
||||
f"ACCOUNT:{kg_person.company}__{relationship_type}__{generalized_target_entity}",
|
||||
}
|
||||
|
||||
return (
|
||||
implied_entities,
|
||||
implied_relationships,
|
||||
company_participant_emails,
|
||||
account_participant_emails,
|
||||
)
|
||||
|
||||
|
||||
def prepare_llm_content_extraction(
|
||||
chunk: KGChunkFormat,
|
||||
company_participant_emails: set[str],
|
||||
account_participant_emails: set[str],
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> str:
|
||||
|
||||
chunk_is_from_call = chunk.source_type.lower() in [
|
||||
call_type.value.lower() for call_type in OnyxCallTypes
|
||||
]
|
||||
|
||||
if chunk_is_from_call:
|
||||
|
||||
llm_context = CALL_CHUNK_PREPROCESSING_PROMPT.format(
|
||||
participant_string=company_participant_emails,
|
||||
account_participant_string=account_participant_emails,
|
||||
vendor=kg_config_settings.KG_VENDOR,
|
||||
content=chunk.content,
|
||||
)
|
||||
else:
|
||||
llm_context = GENERAL_CHUNK_PREPROCESSING_PROMPT.format(
|
||||
content=chunk.content,
|
||||
vendor=kg_config_settings.KG_VENDOR,
|
||||
)
|
||||
|
||||
return llm_context
|
||||
|
||||
|
||||
def prepare_llm_document_content(
|
||||
document_classification_content: KGClassificationContent,
|
||||
category_list: str,
|
||||
category_definitions: dict[str, Dict[str, str | bool]],
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> KGDocumentClassificationPrompt:
|
||||
"""
|
||||
Prepare the content for the extraction classification.
|
||||
"""
|
||||
|
||||
category_definition_string = ""
|
||||
for category, category_data in category_definitions.items():
|
||||
category_definition_string += f"{category}: {category_data['description']}\n"
|
||||
|
||||
if document_classification_content.source_type.lower() in [
|
||||
call_type.value.lower() for call_type in OnyxCallTypes
|
||||
]:
|
||||
return _prepare_llm_document_content_call(
|
||||
document_classification_content,
|
||||
category_list,
|
||||
category_definition_string,
|
||||
kg_config_settings,
|
||||
)
|
||||
|
||||
else:
|
||||
return KGDocumentClassificationPrompt(
|
||||
llm_prompt=None,
|
||||
)
|
||||
|
||||
|
||||
def is_email(email: str) -> bool:
|
||||
"""
|
||||
Check if a string is a valid email address.
|
||||
"""
|
||||
return re.match(r"[^@]+@[^@]+\.[^@]+", email) is not None
|
||||
139
backend/onyx/kg/utils/formatting_utils.py
Normal file
139
backend/onyx/kg/utils/formatting_utils.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.db.kg_config import KGConfigSettings
|
||||
from onyx.kg.models import KGAggregatedExtractions
|
||||
from onyx.kg.models import KGPerson
|
||||
|
||||
|
||||
def format_entity(entity: str) -> str:
|
||||
if len(entity.split(":")) == 2:
|
||||
entity_type, entity_name = entity.split(":")
|
||||
return f"{entity_type.upper()}:{entity_name.title()}"
|
||||
else:
|
||||
return entity
|
||||
|
||||
|
||||
def format_relationship(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{format_entity(source_node)}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{format_entity(target_node)}"
|
||||
)
|
||||
|
||||
|
||||
def format_relationship_type(relationship_type: str) -> str:
|
||||
source_node_type, relationship_type, target_node_type = relationship_type.split(
|
||||
"__"
|
||||
)
|
||||
return (
|
||||
f"{source_node_type.upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node_type.upper()}"
|
||||
)
|
||||
|
||||
|
||||
def generate_relationship_type(relationship: str) -> str:
|
||||
source_node, relationship_type, target_node = relationship.split("__")
|
||||
return (
|
||||
f"{source_node.split(':')[0].upper()}__"
|
||||
f"{relationship_type.lower()}__"
|
||||
f"{target_node.split(':')[0].upper()}"
|
||||
)
|
||||
|
||||
|
||||
def aggregate_kg_extractions(
|
||||
connector_aggregated_kg_extractions_list: list[KGAggregatedExtractions],
|
||||
) -> KGAggregatedExtractions:
|
||||
aggregated_kg_extractions = KGAggregatedExtractions(
|
||||
grounded_entities_document_ids=defaultdict(str),
|
||||
entities=defaultdict(int),
|
||||
relationships=defaultdict(lambda: defaultdict(int)),
|
||||
terms=defaultdict(int),
|
||||
attributes=defaultdict(dict),
|
||||
)
|
||||
for connector_aggregated_kg_extractions in connector_aggregated_kg_extractions_list:
|
||||
for (
|
||||
grounded_entity,
|
||||
document_id,
|
||||
) in connector_aggregated_kg_extractions.grounded_entities_document_ids.items():
|
||||
aggregated_kg_extractions.grounded_entities_document_ids[
|
||||
grounded_entity
|
||||
] = document_id
|
||||
|
||||
for entity, count in connector_aggregated_kg_extractions.entities.items():
|
||||
if entity not in aggregated_kg_extractions.entities:
|
||||
aggregated_kg_extractions.entities[entity] = count
|
||||
else:
|
||||
aggregated_kg_extractions.entities[entity] += count
|
||||
for (
|
||||
relationship,
|
||||
relationship_data,
|
||||
) in connector_aggregated_kg_extractions.relationships.items():
|
||||
for source_document_id, count in relationship_data.items():
|
||||
if relationship not in aggregated_kg_extractions.relationships:
|
||||
aggregated_kg_extractions.relationships[relationship] = defaultdict(
|
||||
int
|
||||
)
|
||||
aggregated_kg_extractions.relationships[relationship][
|
||||
source_document_id
|
||||
] += count
|
||||
for term, count in connector_aggregated_kg_extractions.terms.items():
|
||||
if term not in aggregated_kg_extractions.terms:
|
||||
aggregated_kg_extractions.terms[term] = count
|
||||
else:
|
||||
aggregated_kg_extractions.terms[term] += count
|
||||
|
||||
return aggregated_kg_extractions
|
||||
|
||||
|
||||
def kg_email_processing(email: str, kg_config_settings: KGConfigSettings) -> KGPerson:
|
||||
"""
|
||||
Process the email.
|
||||
"""
|
||||
name, company_domain = email.split("@")
|
||||
assert isinstance(company_domain, str)
|
||||
assert isinstance(kg_config_settings.KG_VENDOR_DOMAINS, list)
|
||||
assert isinstance(kg_config_settings.KG_VENDOR, str)
|
||||
|
||||
employee = any(
|
||||
domain in company_domain for domain in kg_config_settings.KG_VENDOR_DOMAINS
|
||||
)
|
||||
if employee:
|
||||
company = kg_config_settings.KG_VENDOR
|
||||
else:
|
||||
company = company_domain.capitalize()
|
||||
|
||||
return KGPerson(name=name, company=company, employee=employee)
|
||||
|
||||
|
||||
def generalize_entities(entities: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize entities to their superclass.
|
||||
"""
|
||||
return set([f"{entity.split(':')[0]}:*" for entity in entities])
|
||||
|
||||
|
||||
def generalize_relationships(relationships: list[str]) -> set[str]:
|
||||
"""
|
||||
Generalize relationships to their superclass.
|
||||
"""
|
||||
generalized_relationships: set[str] = set()
|
||||
for relationship in relationships:
|
||||
assert (
|
||||
len(relationship.split("__")) == 3
|
||||
), "Relationship is not in the correct format"
|
||||
source_entity, relationship_type, target_entity = relationship.split("__")
|
||||
generalized_source_entity = list(generalize_entities([source_entity]))[0]
|
||||
generalized_target_entity = list(generalize_entities([target_entity]))[0]
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
generalized_relationships.add(
|
||||
f"{generalized_source_entity}__{relationship_type}__{generalized_target_entity}"
|
||||
)
|
||||
|
||||
return generalized_relationships
|
||||
244
backend/onyx/kg/vespa/vespa_interactions.py
Normal file
244
backend/onyx/kg/vespa/vespa_interactions.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from onyx.configs.constants import OnyxCallTypes
|
||||
from onyx.db.kg_config import KGConfigSettings
|
||||
from onyx.document_index.vespa.chunk_retrieval import _get_chunks_via_visit_api
|
||||
from onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest
|
||||
from onyx.document_index.vespa.index import IndexFilters
|
||||
from onyx.kg.models import KGChunkFormat
|
||||
from onyx.kg.models import KGClassificationContent
|
||||
from onyx.kg.utils.formatting_utils import kg_email_processing
|
||||
|
||||
|
||||
def get_document_classification_content_for_kg_processing(
|
||||
document_ids: list[str],
|
||||
source: str,
|
||||
index_name: str,
|
||||
kg_config_settings: KGConfigSettings,
|
||||
batch_size: int = 8,
|
||||
num_classification_chunks: int = 3,
|
||||
entity_type: str | None = None,
|
||||
) -> Generator[list[KGClassificationContent], None, None]:
|
||||
"""
|
||||
Generates the content used for initial classification of a document from
|
||||
the first num_classification_chunks chunks.
|
||||
"""
|
||||
|
||||
classification_content_list: list[KGClassificationContent] = []
|
||||
|
||||
for i in range(0, len(document_ids), batch_size):
|
||||
batch_document_ids = document_ids[i : i + batch_size]
|
||||
for document_id in batch_document_ids:
|
||||
# ... existing code for getting chunks and processing ...
|
||||
first_num_classification_chunks: list[dict] = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(
|
||||
document_id=document_id,
|
||||
max_chunk_ind=num_classification_chunks - 1,
|
||||
min_chunk_ind=0,
|
||||
),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"source_type",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
first_num_classification_chunks = sorted(
|
||||
first_num_classification_chunks, key=lambda x: x["fields"]["chunk_id"]
|
||||
)[:num_classification_chunks]
|
||||
|
||||
classification_content = _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks,
|
||||
kg_config_settings,
|
||||
)
|
||||
|
||||
metadata = first_num_classification_chunks[0]["fields"]["metadata"]
|
||||
if isinstance(metadata, str):
|
||||
metadata = json.loads(metadata)
|
||||
assert isinstance(metadata, dict) or metadata is None
|
||||
|
||||
classification_content_list.append(
|
||||
KGClassificationContent(
|
||||
document_id=document_id,
|
||||
classification_content=classification_content,
|
||||
source_type=first_num_classification_chunks[0]["fields"][
|
||||
"source_type"
|
||||
],
|
||||
source_metadata=metadata,
|
||||
entity_type=entity_type,
|
||||
)
|
||||
)
|
||||
|
||||
# Yield the batch of classification content
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
classification_content_list = []
|
||||
|
||||
# Yield any remaining items
|
||||
if classification_content_list:
|
||||
yield classification_content_list
|
||||
|
||||
|
||||
def get_document_chunks_for_kg_processing(
|
||||
document_id: str,
|
||||
deep_extraction: bool,
|
||||
index_name: str,
|
||||
batch_size: int = 8,
|
||||
) -> Generator[list[KGChunkFormat], None, None]:
|
||||
"""
|
||||
Retrieves chunks from Vespa for the given document IDs and converts them to KGChunks.
|
||||
|
||||
Args:
|
||||
document_ids (list[str]): List of document IDs to fetch chunks for
|
||||
index_name (str): Name of the Vespa index
|
||||
tenant_id (str): ID of the tenant
|
||||
|
||||
Yields:
|
||||
list[KGChunk]: Batches of chunks ready for KG processing
|
||||
"""
|
||||
|
||||
current_batch: list[KGChunkFormat] = []
|
||||
|
||||
# get all chunks for the document
|
||||
chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=IndexFilters(access_control_list=None),
|
||||
field_names=[
|
||||
"document_id",
|
||||
"chunk_id",
|
||||
"title",
|
||||
"content",
|
||||
"metadata",
|
||||
"primary_owners",
|
||||
"secondary_owners",
|
||||
"source_type",
|
||||
"kg_entities",
|
||||
"kg_relationships",
|
||||
"kg_terms",
|
||||
],
|
||||
get_large_chunks=False,
|
||||
)
|
||||
|
||||
# Convert Vespa chunks to KGChunks
|
||||
# kg_chunks: list[KGChunkFormat] = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
fields = chunk["fields"]
|
||||
if isinstance(fields.get("metadata", {}), str):
|
||||
fields["metadata"] = json.loads(fields["metadata"])
|
||||
current_batch.append(
|
||||
KGChunkFormat(
|
||||
connector_id=None, # We may need to adjust this
|
||||
document_id=fields.get("document_id"),
|
||||
chunk_id=fields.get("chunk_id"),
|
||||
primary_owners=fields.get("primary_owners", []),
|
||||
secondary_owners=fields.get("secondary_owners", []),
|
||||
source_type=fields.get("source_type", ""),
|
||||
title=fields.get("title", ""),
|
||||
content=fields.get("content", ""),
|
||||
metadata=fields.get("metadata", {}),
|
||||
entities=fields.get("kg_entities", {}),
|
||||
relationships=fields.get("kg_relationships", {}),
|
||||
terms=fields.get("kg_terms", {}),
|
||||
deep_extraction=deep_extraction,
|
||||
)
|
||||
)
|
||||
|
||||
if len(current_batch) >= batch_size:
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
# Yield any remaining chunks
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
|
||||
def _get_classification_content_from_call_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> str:
|
||||
"""
|
||||
Creates a KGClassificationContent object from a list of call chunks.
|
||||
"""
|
||||
|
||||
assert isinstance(kg_config_settings.KG_IGNORE_EMAIL_DOMAINS, list)
|
||||
|
||||
primary_owners = first_num_classification_chunks[0]["fields"].get(
|
||||
"primary_owners", []
|
||||
)
|
||||
secondary_owners = first_num_classification_chunks[0]["fields"].get(
|
||||
"secondary_owners", []
|
||||
)
|
||||
|
||||
company_participant_emails = set()
|
||||
account_participant_emails = set()
|
||||
|
||||
for owner in primary_owners + secondary_owners:
|
||||
kg_owner = kg_email_processing(owner, kg_config_settings)
|
||||
if any(
|
||||
domain.lower() in kg_owner.company.lower()
|
||||
for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS
|
||||
):
|
||||
continue
|
||||
|
||||
if kg_owner.employee:
|
||||
company_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
else:
|
||||
account_participant_emails.add(f"{kg_owner.name} -- ({kg_owner.company})")
|
||||
|
||||
participant_string = "\n - " + "\n - ".join(company_participant_emails)
|
||||
account_participant_string = "\n - " + "\n - ".join(account_participant_emails)
|
||||
|
||||
title_string = first_num_classification_chunks[0]["fields"]["title"]
|
||||
content_string = "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
|
||||
classification_content = f"{title_string}\n\nVendor Participants:\n{participant_string}\n\n\
|
||||
Other Participants:\n{account_participant_string}\n\nBeginning of Call:\n{content_string}"
|
||||
|
||||
return classification_content
|
||||
|
||||
|
||||
def _get_classification_content_from_chunks(
|
||||
first_num_classification_chunks: list[dict],
|
||||
kg_config_settings: KGConfigSettings,
|
||||
) -> str:
|
||||
"""
|
||||
Creates a KGClassificationContent object from a list of chunks.
|
||||
"""
|
||||
|
||||
source_type = first_num_classification_chunks[0]["fields"]["source_type"]
|
||||
|
||||
if source_type.lower() in [call_type.value.lower() for call_type in OnyxCallTypes]:
|
||||
classification_content = _get_classification_content_from_call_chunks(
|
||||
first_num_classification_chunks,
|
||||
kg_config_settings,
|
||||
)
|
||||
|
||||
else:
|
||||
classification_content = (
|
||||
first_num_classification_chunks[0]["fields"]["title"]
|
||||
+ "\n"
|
||||
+ "\n".join(
|
||||
[
|
||||
chunk_content["fields"]["content"]
|
||||
for chunk_content in first_num_classification_chunks
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return classification_content
|
||||
@@ -40,6 +40,8 @@ from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE
|
||||
from onyx.configs.app_configs import SYSTEM_RECURSION_LIMIT
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -208,6 +210,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
)
|
||||
SqlEngine.get_engine()
|
||||
|
||||
SqlEngine.init_readonly_engine(
|
||||
pool_size=POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
|
||||
)
|
||||
|
||||
verify_auth = fetch_versioned_implementation(
|
||||
"onyx.auth.users", "verify_auth_setting"
|
||||
)
|
||||
|
||||
1317
backend/onyx/prompts/kg_prompts.py
Normal file
1317
backend/onyx/prompts/kg_prompts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,6 +81,11 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
document_sources: list[DocumentSource] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
kg_entities: list[str] | None = None
|
||||
kg_relationships: list[str] | None = None
|
||||
kg_terms: list[str] | None = None
|
||||
kg_sources: list[str] | None = None
|
||||
kg_chunk_id_zero_only: bool | None = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -299,6 +299,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_sources = None
|
||||
time_cutoff = None
|
||||
expanded_queries = None
|
||||
kg_entities = None
|
||||
kg_relationships = None
|
||||
kg_terms = None
|
||||
kg_sources = None
|
||||
kg_chunk_id_zero_only = False
|
||||
if override_kwargs:
|
||||
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
@@ -312,7 +317,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_sources = override_kwargs.document_sources
|
||||
time_cutoff = override_kwargs.time_cutoff
|
||||
expanded_queries = override_kwargs.expanded_queries
|
||||
|
||||
kg_entities = override_kwargs.kg_entities
|
||||
kg_relationships = override_kwargs.kg_relationships
|
||||
kg_terms = override_kwargs.kg_terms
|
||||
kg_sources = override_kwargs.kg_sources
|
||||
kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False
|
||||
# Fast path for ordering-only search
|
||||
if ordering_only:
|
||||
yield from self._run_ordering_only_search(
|
||||
@@ -361,6 +370,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Overwrite time-cutoff should supercede existing time-cutoff, even if defined
|
||||
retrieval_options.filters.time_cutoff = time_cutoff
|
||||
|
||||
retrieval_options = retrieval_options or RetrievalDetails()
|
||||
retrieval_options.filters = retrieval_options.filters or BaseFilters()
|
||||
if kg_entities:
|
||||
retrieval_options.filters.kg_entities = kg_entities
|
||||
if kg_relationships:
|
||||
retrieval_options.filters.kg_relationships = kg_relationships
|
||||
if kg_terms:
|
||||
retrieval_options.filters.kg_terms = kg_terms
|
||||
if kg_sources:
|
||||
retrieval_options.filters.kg_sources = kg_sources
|
||||
if kg_chunk_id_zero_only:
|
||||
retrieval_options.filters.kg_chunk_id_zero_only = kg_chunk_id_zero_only
|
||||
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
|
||||
@@ -80,6 +80,7 @@ slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
starlette==0.46.1
|
||||
supervisor==4.2.5
|
||||
thefuzz==0.22.1
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.49.0
|
||||
|
||||
@@ -155,7 +155,7 @@
|
||||
"Email: fiannellib46@marriott.com\nIsDeleted: false\nLastName: Iannelli\nIsEmailBounced: false\nFirstName: Felicio\nIsPriorityRecord: false\nCleanStatus: Pending"
|
||||
],
|
||||
"semantic_identifier": "Voonder",
|
||||
"metadata": {},
|
||||
"metadata": {"object_type": "Account"},
|
||||
"primary_owners": {"email": "hagen@danswer.ai", "first_name": "Hagen", "last_name": "oneill"},
|
||||
"secondary_owners": null,
|
||||
"title": null
|
||||
|
||||
@@ -51,15 +51,21 @@ def answer_instance(
|
||||
"onyx.chat.answer.fast_gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
|
||||
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config, mocker)
|
||||
|
||||
|
||||
def _answer_fixture_impl(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> Answer:
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_query = Mock()
|
||||
mock_query.all.return_value = []
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
return Answer(
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
@@ -74,7 +80,7 @@ def _answer_fixture_impl(
|
||||
raw_user_query=QUERY,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
db_session=Mock(spec=Session),
|
||||
db_session=mock_db_session,
|
||||
answer_style_config=answer_style_config,
|
||||
llm=mock_llm,
|
||||
fast_llm=mock_llm,
|
||||
@@ -404,7 +410,11 @@ def test_no_slow_reranking(
|
||||
)
|
||||
)
|
||||
answer_instance = _answer_fixture_impl(
|
||||
mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
|
||||
mock_llm,
|
||||
answer_style_config,
|
||||
prompt_config,
|
||||
mocker,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
assert (
|
||||
|
||||
@@ -43,6 +43,7 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
"onyx.chat.answer.fast_gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
question = config["question"]
|
||||
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
|
||||
|
||||
@@ -59,8 +60,14 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
mock_llm.stream = Mock()
|
||||
mock_llm.stream.return_value = [Mock()]
|
||||
|
||||
# Set up the mock database session
|
||||
mock_db_session = Mock(spec=Session)
|
||||
mock_query = Mock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.all.return_value = [] # Return empty list for KGConfig query
|
||||
|
||||
answer = Answer(
|
||||
db_session=Mock(spec=Session),
|
||||
db_session=mock_db_session,
|
||||
answer_style_config=answer_style_config,
|
||||
llm=mock_llm,
|
||||
fast_llm=mock_llm,
|
||||
|
||||
@@ -747,11 +747,17 @@ def test_salesforce_sqlite() -> None:
|
||||
sf_db.apply_schema()
|
||||
|
||||
_create_csv_with_example_data(sf_db)
|
||||
|
||||
_test_query(sf_db)
|
||||
|
||||
_test_upsert(sf_db)
|
||||
|
||||
_test_relationships(sf_db)
|
||||
|
||||
_test_account_with_children(sf_db)
|
||||
|
||||
_test_relationship_updates(sf_db)
|
||||
|
||||
_test_get_affected_parent_ids(sf_db)
|
||||
|
||||
sf_db.close()
|
||||
|
||||
@@ -183,6 +183,8 @@ services:
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_USER=${POSTGRES_USER:-}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
|
||||
@@ -383,6 +385,8 @@ services:
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
|
||||
@@ -148,6 +148,8 @@ services:
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_USER=${POSTGRES_USER:-}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
@@ -329,6 +331,8 @@ services:
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
|
||||
@@ -166,6 +166,8 @@ services:
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_USER=${POSTGRES_USER:-}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- VESPA_HOST=index
|
||||
@@ -356,6 +358,8 @@ services:
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
|
||||
@@ -153,6 +153,8 @@ services:
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
- DB_READONLY_USER=${DB_READONLY_USER:-}
|
||||
- DB_READONLY_PASSWORD=${DB_READONLY_PASSWORD:-}
|
||||
ports:
|
||||
- "5432"
|
||||
volumes:
|
||||
|
||||
@@ -55,3 +55,10 @@ SESSION_EXPIRE_TIME_SECONDS=604800
|
||||
# Default values here are what Postgres uses by default, feel free to change.
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=password
|
||||
|
||||
|
||||
# Default values here for the read-only user for the knowledge graph and other future read-only purposes.
|
||||
# Please change password!
|
||||
DB_READONLY_USER=db_readonly_user
|
||||
DB_READONLY_PASSWORD=password
|
||||
|
||||
|
||||
@@ -40,6 +40,16 @@ spec:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: postgres_password
|
||||
- name: DB_READONLY_USER
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: DB_READONLY_user
|
||||
- name: DB_READONLY_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: DB_READONLY_password
|
||||
args: ["-c", "max_connections=250"]
|
||||
ports:
|
||||
- containerPort: 5432
|
||||
|
||||
@@ -532,7 +532,7 @@ export function AssistantEditor({
|
||||
|
||||
// if disable_retrieval is set, set num_chunks to 0
|
||||
// to tell the backend to not fetch any documents
|
||||
const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0;
|
||||
const numChunks = searchToolEnabled ? values.num_chunks || 25 : 0;
|
||||
const starterMessages = values.starter_messages
|
||||
.filter(
|
||||
(message: { message: string }) => message.message.trim() !== ""
|
||||
|
||||
Reference in New Issue
Block a user