mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
55 Commits
v0.23.1
...
batch_proc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
687122911d | ||
|
|
40953bd4fe | ||
|
|
a7acc07e79 | ||
|
|
b6e9e65bb8 | ||
|
|
20f2b9b2bb | ||
|
|
f731beca1f | ||
|
|
fe246aecbb | ||
|
|
50ad066712 | ||
|
|
870b59a1cc | ||
|
|
5c896cb0f7 | ||
|
|
184b30643d | ||
|
|
ae585fd84c | ||
|
|
61e8f371b9 | ||
|
|
33cc4be492 | ||
|
|
117c8c0d78 | ||
|
|
9bb8cdfff1 | ||
|
|
a52d0d29be | ||
|
|
f25e1e80f6 | ||
|
|
39fd6919ad | ||
|
|
7f0653d173 | ||
|
|
e9905a398b | ||
|
|
3ed44e8bae | ||
|
|
64158a5bdf | ||
|
|
afb2393596 | ||
|
|
d473c4e876 | ||
|
|
692058092f | ||
|
|
e88325aad6 | ||
|
|
7490250e91 | ||
|
|
e5369fcef8 | ||
|
|
b0f00953bc | ||
|
|
f6a75c86c6 | ||
|
|
ed9989282f | ||
|
|
e80a0f2716 | ||
|
|
909403a648 | ||
|
|
cd84b65011 | ||
|
|
413f21cec0 | ||
|
|
eb369384a7 | ||
|
|
0a24dbc52c | ||
|
|
a7ba0da8cc | ||
|
|
aaced6d551 | ||
|
|
4c230f92ea | ||
|
|
07d75b04d1 | ||
|
|
a8d10750c1 | ||
|
|
85e3ed57f1 | ||
|
|
e10cc8ccdb | ||
|
|
7018bc974b | ||
|
|
9c9075d71d | ||
|
|
338e084062 | ||
|
|
2f64031f5c | ||
|
|
abb74f2eaa | ||
|
|
a3e3d83b7e | ||
|
|
4dc88ca037 | ||
|
|
11e7e1c4d6 | ||
|
|
f2d74ce540 | ||
|
|
25389c5120 |
1
.github/CODEOWNERS
vendored
Normal file
1
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
94
.github/workflows/nightly-scan-licenses.yml
vendored
94
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -53,24 +53,90 @@ jobs:
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
if: always()
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@0.29.0
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
397
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
397
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""improved index
|
||||
|
||||
Revision ID: 3bd4c84fe72f
|
||||
Revises: 8f43500ee275
|
||||
Create Date: 2025-02-26 13:07:56.217791
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import time
|
||||
from sqlalchemy import text
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3bd4c84fe72f"
|
||||
down_revision = "8f43500ee275"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# NOTE:
|
||||
# This migration addresses issues with the previous migration (8f43500ee275) which caused
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
# Rows backfilled per UPDATE statement. Every batch is committed on its own so
# that row locks are released between batches.
_BACKFILL_BATCH_SIZE = 5000


def _tsvector_expr(column_ref: str, null_safe: bool) -> str:
    """Return the SQL expression converting `column_ref` to an English tsvector.

    When `null_safe` is true the value is wrapped in COALESCE(..., '') so that
    a NULL source produces an empty tsvector instead of NULL.
    """
    value = f"COALESCE({column_ref}, '')" if null_safe else column_ref
    return f"to_tsvector('english', {value})"


def _commit(connection) -> None:
    """Issue a raw COMMIT on the migration connection.

    CREATE/DROP INDEX CONCURRENTLY cannot run inside a transaction block, so
    the transaction is ended explicitly before each of those statements.
    """
    connection.execute(text("COMMIT"))


def _add_indexed_tsvector_column(
    connection,
    *,
    table: str,
    source_column: str,
    tsv_column: str,
    index_name: str,
    trigger_name: str,
    function_name: str,
    null_safe: bool,
) -> None:
    """Give `table` a generated, GIN-indexed tsvector column named `tsv_column`.

    The statements are issued in this order:
      1. add a plain nullable `tsv_column`;
      2. install a trigger that fills it for rows inserted/updated meanwhile;
      3. backfill existing rows in committed batches;
      4. build a GIN index on it concurrently;
      5. drop the trigger and its function;
      6. add `<tsv_column>_gen` as a GENERATED ... STORED column and index it
         concurrently as `<index_name>_gen`;
      7. drop the temporary index and column and rename the generated column
         to `tsv_column`.

    All identifiers are interpolated into the SQL text. They are hard-coded
    constants supplied by `upgrade()`, never external input.

    Args:
        connection: the connection returned by `op.get_bind()`.
        table: table to alter.
        source_column: text column the tsvector is derived from.
        tsv_column: final name of the tsvector column.
        index_name: name of the temporary index; the permanent index is
            created as `<index_name>_gen`.
        trigger_name: name of the temporary trigger.
        function_name: name of the temporary trigger function.
        null_safe: wrap the source value in COALESCE(..., '').
    """
    gen_column = f"{tsv_column}_gen"
    gen_index = f"{index_name}_gen"
    row_expr = _tsvector_expr(source_column, null_safe)
    new_row_expr = _tsvector_expr(f"NEW.{source_column}", null_safe)

    # Step 1: Add nullable column (quick, minimal locking)
    connection.execute(
        text(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {tsv_column} tsvector")
    )

    # Step 2: Create function and trigger for new/updated rows
    connection.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION {function_name}()
            RETURNS TRIGGER AS $$
            BEGIN
                NEW.{tsv_column} = {new_row_expr};
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql
            """
        )
    )

    # This migration commits midway, so a failed run can leave the trigger
    # behind. Drop it first so that CREATE TRIGGER does not fail on a re-run.
    connection.execute(text(f"DROP TRIGGER IF EXISTS {trigger_name} ON {table}"))
    connection.execute(
        text(
            f"""
            CREATE TRIGGER {trigger_name}
            BEFORE INSERT OR UPDATE ON {table}
            FOR EACH ROW EXECUTE FUNCTION {function_name}()
            """
        )
    )

    # Step 3: Update existing rows in batches
    total_rows = connection.execute(text(f"SELECT COUNT(*) FROM {table}")).scalar() or 0
    max_batches = (total_rows + _BACKFILL_BATCH_SIZE - 1) // _BACKFILL_BATCH_SIZE

    # No OFFSET here: every batch shrinks the set of NULL rows, so always
    # taking the first `batch_size` remaining rows visits each row once.
    # Combining this filter with a growing OFFSET would skip rows.
    backfill_batch = text(
        f"""
        UPDATE {table}
        SET {tsv_column} = {row_expr}
        WHERE id IN (
            SELECT id FROM {table}
            WHERE {tsv_column} IS NULL
            ORDER BY id
            LIMIT :batch_size
        )
        """
    ).bindparams(batch_size=_BACKFILL_BATCH_SIZE)

    # The loop is bounded by the initial row count so it always terminates,
    # even if some rows can never become non-NULL (e.g. a NULL source value
    # when null_safe is false).
    for _ in range(max_batches):
        result = connection.execute(backfill_batch)

        # Commit each batch, then start a new transaction
        _commit(connection)
        connection.execute(text("BEGIN"))

        if result.rowcount == 0:
            break

    # Final check for any remaining NULL values
    connection.execute(
        text(
            f"""
            UPDATE {table} SET {tsv_column} = {row_expr}
            WHERE {tsv_column} IS NULL
            """
        )
    )

    # Step 4: Create GIN index concurrently
    _commit(connection)
    connection.execute(
        text(
            f"""
            CREATE INDEX CONCURRENTLY IF NOT EXISTS {index_name}
            ON {table} USING GIN ({tsv_column})
            """
        )
    )

    # Step 5: Drop the trigger and function as they won't be needed anymore
    connection.execute(text(f"DROP TRIGGER IF EXISTS {trigger_name} ON {table};"))
    connection.execute(text(f"DROP FUNCTION IF EXISTS {function_name}();"))

    # Step 6: Add new generated column and index it concurrently
    connection.execute(
        text(
            f"""
            ALTER TABLE {table}
            ADD COLUMN {gen_column} tsvector
            GENERATED ALWAYS AS ({row_expr}) STORED;
            """
        )
    )
    _commit(connection)
    connection.execute(
        text(
            f"""
            CREATE INDEX CONCURRENTLY IF NOT EXISTS {gen_index}
            ON {table} USING GIN ({gen_column})
            """
        )
    )

    # Step 7: Drop old index and column, then rename new column to old name
    _commit(connection)
    connection.execute(text(f"DROP INDEX CONCURRENTLY IF EXISTS {index_name};"))
    _commit(connection)
    connection.execute(text(f"ALTER TABLE {table} DROP COLUMN {tsv_column};"))
    connection.execute(
        text(f"ALTER TABLE {table} RENAME COLUMN {gen_column} TO {tsv_column};")
    )


def upgrade() -> None:
    """Add generated, GIN-indexed tsvector columns for chat full-text search.

    After this runs, `chat_message.message_tsv` and
    `chat_session.description_tsv` are GENERATED ... STORED tsvector columns,
    indexed by `idx_chat_message_tsv_gen` and `idx_chat_session_desc_tsv_gen`.
    All index builds use CONCURRENTLY, which is why transactions are committed
    explicitly along the way.
    """
    connection = op.get_bind()

    # --- PART 1: chat_message table ---
    _add_indexed_tsvector_column(
        connection,
        table="chat_message",
        source_column="message",
        tsv_column="message_tsv",
        index_name="idx_chat_message_tsv",
        trigger_name="chat_message_tsv_trigger",
        function_name="update_chat_message_tsv",
        null_safe=False,
    )

    # --- PART 2: chat_session table ---
    _add_indexed_tsvector_column(
        connection,
        table="chat_session",
        source_column="description",
        tsv_column="description_tsv",
        index_name="idx_chat_session_desc_tsv",
        trigger_name="chat_session_tsv_trigger",
        function_name="update_chat_session_tsv",
        null_safe=True,
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the tsvector columns and their GIN indexes.

    By the end of `upgrade()` the surviving indexes are the `*_gen` ones: the
    generated columns are renamed, but their indexes keep the `_gen` suffix.
    Both the suffixed and unsuffixed names are dropped so that a partially
    applied upgrade is cleaned up as well.
    """
    # Drop the indexes first (use CONCURRENTLY for dropping too). Each
    # concurrent drop must run outside a transaction, hence the COMMIT.
    for index_name in (
        "idx_chat_message_tsv_gen",
        "idx_chat_message_tsv",
        "idx_chat_session_desc_tsv_gen",
        "idx_chat_session_desc_tsv",
    ):
        op.execute("COMMIT")
        op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {index_name};")

    # Then drop the columns
    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

    # NOTE(review): this index is not created by this migration; presumably it
    # comes from the previous revision (8f43500ee275) - confirm.
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -0,0 +1,55 @@
|
||||
"""add background_reindex_enabled field
|
||||
|
||||
Revision ID: b7c2b63c4a03
|
||||
Revises: f11b408e39d3
|
||||
Create Date: 2024-03-26 12:34:56.789012
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7c2b63c4a03"
|
||||
down_revision = "f11b408e39d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add three settings columns to the `search_settings` table."""
    target_table = "search_settings"

    # background_reindex_enabled: required boolean, server default "true"
    reindex_flag = sa.Column(
        "background_reindex_enabled",
        sa.Boolean(),
        nullable=False,
        server_default="true",
    )
    op.add_column(target_table, reindex_flag)

    # embedding_precision: required non-native enum, server default FLOAT
    precision = sa.Column(
        "embedding_precision",
        sa.Enum(EmbeddingPrecision, native_enum=False),
        nullable=False,
        server_default=EmbeddingPrecision.FLOAT.name,
    )
    op.add_column(target_table, precision)

    # reduced_dimension: optional integer, no default (NULL)
    reduced_dimension = sa.Column("reduced_dimension", sa.Integer(), nullable=True)
    op.add_column(target_table, reduced_dimension)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the three `search_settings` columns added by `upgrade()`."""
    for column_name in (
        "background_reindex_enabled",
        "embedding_precision",
        "reduced_dimension",
    ):
        op.drop_column("search_settings", column_name)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""force lowercase all users
|
||||
|
||||
Revision ID: f11b408e39d3
|
||||
Revises: 3bd4c84fe72f
|
||||
Create Date: 2025-02-26 17:04:55.683500
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f11b408e39d3"
|
||||
down_revision = "3bd4c84fe72f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Lowercase every stored user email and enforce lowercase from now on."""
    from alembic import op

    # 1) Rewrite all existing user emails in lowercase
    lowercase_existing_emails = """
        UPDATE "user"
        SET email = LOWER(email)
        """
    op.execute(lowercase_existing_emails)

    # 2) Reject any future write whose email is not already lowercase
    lowercase_rule = "email = LOWER(email)"
    op.create_check_constraint("ensure_lowercase_email", "user", lowercase_rule)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the lowercase-email check constraint from the `user` table.

    Emails lowercased by `upgrade()` are left as they are.
    """
    from alembic import op

    constraint_name = "ensure_lowercase_email"
    op.drop_constraint(constraint_name, "user", type_="check")
|
||||
@@ -0,0 +1,42 @@
|
||||
"""lowercase multi-tenant user auth
|
||||
|
||||
Revision ID: 34e3630c7f32
|
||||
Revises: a4f6ee863c47
|
||||
Create Date: 2025-02-26 15:03:01.211894
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34e3630c7f32"
|
||||
down_revision = "a4f6ee863c47"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Lowercase all `user_tenant_mapping` emails and enforce lowercase."""
    # 1) Rewrite every existing row's email in lowercase
    lowercase_existing_rows = """
        UPDATE user_tenant_mapping
        SET email = LOWER(email)
        """
    op.execute(lowercase_existing_rows)

    # 2) Add a check constraint so that emails cannot be written in uppercase
    constraint_args = ("ensure_lowercase_email", "user_tenant_mapping")
    op.create_check_constraint(
        *constraint_args,
        "email = LOWER(email)",
        schema="public",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the lowercase-email check constraint from `user_tenant_mapping`.

    Emails lowercased by `upgrade()` are left as they are.
    """
    constraint_name = "ensure_lowercase_email"
    table_name = "user_tenant_mapping"
    op.drop_constraint(constraint_name, table_name, schema="public", type_="check")
|
||||
@@ -4,7 +4,8 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -18,7 +19,26 @@ logger = setup_logger()
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
|
||||
@@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
|
||||
@@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime | None = None,
|
||||
) -> list[ChatSession]:
|
||||
time_order: UnaryExpression = desc(ChatSession.time_created)
|
||||
"""Sorted by oldest to newest, then by message id"""
|
||||
|
||||
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
|
||||
message_order: UnaryExpression = asc(ChatMessage.id)
|
||||
|
||||
filters: list[ColumnElement | BinaryExpression] = [
|
||||
@@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
subquery = (
|
||||
db_session.query(ChatSession.id, ChatSession.time_created)
|
||||
.filter(*filters)
|
||||
.order_by(ChatSession.id, time_order)
|
||||
.distinct(ChatSession.id)
|
||||
.order_by(asc_time_order)
|
||||
.limit(limit)
|
||||
.subquery()
|
||||
)
|
||||
@@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
||||
ChatMessage.chat_message_feedbacks
|
||||
),
|
||||
)
|
||||
.order_by(time_order, message_order)
|
||||
.order_by(asc_time_order, message_order)
|
||||
)
|
||||
|
||||
chat_sessions = query.all()
|
||||
|
||||
@@ -16,13 +16,18 @@ from onyx.db.models import UsageReport
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
|
||||
|
||||
# Gets skeletons of all message
|
||||
# Gets skeletons of all messages in the given range
|
||||
def get_empty_chat_messages_entries__paginated(
|
||||
db_session: Session,
|
||||
period: tuple[datetime, datetime],
|
||||
limit: int | None = 500,
|
||||
initial_time: datetime | None = None,
|
||||
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
|
||||
"""Returns a tuple where:
|
||||
first element is the most recent timestamp out of the sessions iterated
|
||||
- this timestamp can be used to paginate forward in time
|
||||
second element is a list of messages belonging to all the sessions iterated
|
||||
"""
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=period[0],
|
||||
end=period[1],
|
||||
@@ -52,18 +57,17 @@ def get_empty_chat_messages_entries__paginated(
|
||||
if len(chat_sessions) == 0:
|
||||
return None, []
|
||||
|
||||
return chat_sessions[0].time_created, message_skeletons
|
||||
return chat_sessions[-1].time_created, message_skeletons
|
||||
|
||||
|
||||
def get_all_empty_chat_message_entries(
|
||||
db_session: Session,
|
||||
period: tuple[datetime, datetime],
|
||||
) -> Generator[list[ChatMessageSkeleton], None, None]:
|
||||
"""period is the range of time over which to fetch messages."""
|
||||
initial_time: Optional[datetime] = period[0]
|
||||
ind = 0
|
||||
while True:
|
||||
ind += 1
|
||||
|
||||
# iterate from oldest to newest
|
||||
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
|
||||
db_session,
|
||||
period,
|
||||
|
||||
@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
|
||||
)
|
||||
|
||||
# if the user is a curator in any of their groups, set their role to CURATOR
|
||||
# otherwise, set their role to BASIC
|
||||
# otherwise, set their role to BASIC only if they were previously a CURATOR
|
||||
if curator_relationships:
|
||||
user.role = UserRole.CURATOR
|
||||
elif user.role == UserRole.CURATOR:
|
||||
@@ -631,7 +631,16 @@ def update_user_group(
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
users_to_validate = [
|
||||
user
|
||||
for user in removed_users
|
||||
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
|
||||
]
|
||||
|
||||
if users_to_validate:
|
||||
_validate_curator_status__no_commit(db_session, users_to_validate)
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
@@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -354,7 +359,11 @@ def confluence_doc_sync(
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), "confluence", cc_pair.credential_id
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -61,13 +63,27 @@ def _build_group_member_email_map(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
url = wiki_base.rstrip("/")
|
||||
|
||||
probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
confluence_client = OnyxConfluence(is_cloud, url, provider)
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
|
||||
@@ -32,7 +32,8 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -119,6 +119,7 @@ def _build_onyx_groups(
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
|
||||
@@ -123,7 +123,8 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
|
||||
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth import router as oauth_router
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -152,4 +152,8 @@ def get_application() -> FastAPI:
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
return application
|
||||
|
||||
@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
|
||||
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.utils.logger import OnyxLoggingAdapter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -216,7 +216,7 @@ def _handle_standard_answers(
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
@@ -231,6 +231,7 @@ def _handle_standard_answers(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
receiver_ids=receiver_ids,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,629 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
# Module-level logger, created once at import time via setup_logger().
logger = setup_logger()
|
||||
|
||||
# Router that the OAuth endpoints in this module are registered on; every
# route path declared with @router.* is prefixed with "/oauth".
router = APIRouter(prefix="/oauth")
|
||||
|
||||
|
||||
class SlackOAuth:
    """Constants and helpers for the Slack OAuth v2 authorization flow.

    References:
    - https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
    - https://api.slack.com/authentication/oauth-v2#exchanging
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if the OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_SLACK_CLIENT_ID
    CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

    TOKEN_URL = "https://slack.com/api/oauth.v2.access"

    # Bot scopes, per https://docs.onyx.app/connectors/slack
    # (comma separated, as placed verbatim in the authorize URL below)
    BOT_SCOPE = ",".join(
        (
            "channels:history",
            "channels:read",
            "groups:history",
            "groups:read",
            "channels:join",
            "im:history",
            "users:read",
            "users:read.email",
            "usergroups:read",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Slack authorize URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """Return the authorize URL that redirects back to DEV_REDIRECT_URI.

        Dev mode workaround for localhost testing, see
        https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL; values are inserted as-is (no escaping)."""
        query = "&".join(
            (
                f"client_id={cls.CLIENT_ID}",
                f"redirect_uri={redirect_uri}",
                f"scope={cls.BOT_SCOPE}",
                f"state={state}",
            )
        )
        return f"https://slack.com/oauth/v2/authorize?{query}"

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary flow state to a JSON string.

        The string is meant to be stored in redis and looked up again when
        the auth response arrives.
        """
        return SlackOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the JSON made by session_dump_json."""
        return SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
    """Constants and helpers for the Confluence Cloud OAuth 2.0 (3LO) flow.

    Work in progress.

    Reference: https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
    TOKEN_URL = "https://auth.atlassian.com/oauth/token"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    # Scopes are joined with a pre-encoded space (%20) because the string is
    # placed verbatim into the authorize URL.
    CONFLUENCE_OAUTH_SCOPE = (
        "read:confluence-props%20"
        "read:confluence-content.all%20"
        "read:confluence-content.summary%20"
        "read:confluence-content.permission%20"
        "read:confluence-user%20"
        "read:confluence-groups%20"
        "readonly:content.attachment:confluence"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Atlassian authorize URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL; values are inserted as-is (no escaping)."""
        url = (
            "https://auth.atlassian.com/authorize"
            f"?audience=api.atlassian.com"
            f"&client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
            f"&state={state}"
            "&response_type=code"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        session = ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild the session object from the JSON made by session_dump_json.

        Validates against this class's own OAuthSession model. The previous
        code validated against (and returned) SlackOAuth.OAuthSession, a
        copy-paste slip that only worked because the two models currently
        declare the same fields.
        """
        session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
        return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
    """Constants and helpers for the Google Drive OAuth authorization flow.

    References:
    - https://developers.google.com/identity/protocols/oauth2
    - https://developers.google.com/identity/protocols/oauth2/web-server
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if the OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.onyx.app/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    # Joined with a pre-encoded space (%20) since the string goes verbatim
    # into the authorize URL.
    SCOPE = "%20".join(
        (
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive.metadata.readonly",
            "https://www.googleapis.com/auth/admin.directory.user.readonly",
            "https://www.googleapis.com/auth/admin.directory.group.readonly",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Google authorize URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """Return the authorize URL that redirects to DEV_REDIRECT_URI.

        Dev mode workaround for localhost testing, see
        https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL; values are inserted as-is (no escaping)."""
        # without prompt=consent, a refresh token is only issued the first
        # time the user approves
        query = "&".join(
            (
                f"client_id={cls.CLIENT_ID}",
                f"redirect_uri={redirect_uri}",
                "response_type=code",
                f"scope={cls.SCOPE}",
                "access_type=offline",
                f"state={state}",
                "prompt=consent",
            )
        )
        return f"https://accounts.google.com/o/oauth2/v2/auth?{query}"

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary flow state to a JSON string.

        The string is meant to be stored in redis and looked up again when
        the auth response arrives.
        """
        return GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the JSON made by session_dump_json."""
        return GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
    connector: DocumentSource,
    redirect_on_success: str | None,
    user: User = Depends(current_user),
) -> JSONResponse:
    """Build the provider authorization URL the frontend sends the browser to.

    A random UUID identifies the flow: its url-safe base64 form (padding
    stripped) is the OAuth ``state`` parameter, and its string form keys the
    session JSON stored in redis for 10 minutes.

    Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/

    Raises:
        HTTPException: 404 when ``connector`` has no OAuth flow wired up here.
    """
    tenant_id = get_current_tenant_id()

    # random per-flow id: protects the flow and lets the callback find the
    # stored user data again
    flow_id = uuid.uuid4()
    redis_key = f"da_oauth:{flow_id}"

    # url-safe base64 of the raw uuid bytes, without the trailing padding
    oauth_state = base64.urlsafe_b64encode(flow_id.bytes).decode("utf-8").rstrip("=")

    oauth_url: str | None = None
    session = ""

    # Only Slack and Google Drive are enabled; Confluence and Jira are not
    # wired up in this endpoint yet.
    if connector == DocumentSource.SLACK:
        oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
        session = SlackOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )
    elif connector == DocumentSource.GOOGLE_DRIVE:
        oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
        session = GoogleDriveOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )

    if not oauth_url:
        raise HTTPException(
            status_code=404,
            detail=f"The document source type {connector} does not have OAuth implemented",
        )

    # keep the session state around for the redirect back;
    # 10 min is the max we want an oauth flow to be valid
    redis_client = get_redis_client(tenant_id=tenant_id)
    redis_client.set(redis_key, session, ex=600)

    return JSONResponse(content={"url": oauth_url})
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> JSONResponse:
    """Complete the Slack OAuth flow and store the resulting bot token.

    Looks up the session stored in redis under the key derived from
    ``state``, exchanges ``code`` for an access token at
    ``SlackOAuth.TOKEN_URL`` and saves the token as a Slack credential via
    ``create_credential``.

    Raises:
        HTTPException: 500 if the Slack client id/secret is not configured;
            400 if ``state`` is malformed, the stored session is not found,
            or Slack reports the token exchange as not ok.

    Any other failure during the exchange is reported as a 500 JSON response
    with ``success=False``. Once the session has been found, the redis key is
    always deleted, whether or not the exchange succeeded.
    """
    if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Slack client ID or client secret is not configured.",
        )

    # Resolve the tenant explicitly, mirroring prepare_authorization_request,
    # so the state is read from the same redis client it was written with.
    r = get_redis_client(tenant_id=get_current_tenant_id())

    # recover the state: add the padding back (Base64 decoding requires it),
    # decode to bytes and convert the bytes back to a UUID
    padded_state = state + "=" * (-len(state) % 4)
    try:
        oauth_uuid = uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state))
    except ValueError as e:
        # binascii.Error (bad base64) is a ValueError; so is a wrong byte count
        raise HTTPException(
            status_code=400,
            detail="Slack OAuth failed - malformed OAuth state parameter",
        ) from e
    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = SlackOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            SlackOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": SlackOAuth.CLIENT_ID,
                "client_secret": SlackOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": SlackOAuth.REDIRECT_URI,
            },
            # requests never times out by default; don't hang the request
            # handler on an unresponsive token endpoint
            timeout=30,
        )

        response_data = response.json()

        if not response_data.get("ok"):
            raise HTTPException(
                status_code=400,
                detail=f"Slack OAuth failed: {response_data.get('error')}",
            )

        # Extract token and team information
        access_token: str = response_data.get("access_token")
        team_id: str = response_data.get("team", {}).get("id")
        authed_user_id: str = response_data.get("authed_user", {}).get("id")

        credential_info = CredentialBase(
            credential_json={"slack_bot_token": access_token},
            admin_public=True,
            source=DocumentSource.SLACK,
            name="Slack OAuth",
        )

        create_credential(credential_info, user, db_session)
    except HTTPException:
        # keep the intended status code and detail instead of letting the
        # generic handler below mask it as a 500
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Slack OAuth: {str(e)}",
            },
        )
    finally:
        # the state is single-use regardless of the outcome
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Slack OAuth completed successfully.",
            "team_id": team_id,
            "authed_user_id": authed_user_id,
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
|
||||
|
||||
# Work in progress
|
||||
# @router.post("/connector/confluence/callback")
|
||||
# def handle_confluence_oauth_callback(
|
||||
# code: str,
|
||||
# state: str,
|
||||
# user: User = Depends(current_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
# tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
# ) -> JSONResponse:
|
||||
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
# raise HTTPException(
|
||||
# status_code=500,
|
||||
# detail="Confluence client ID or client secret is not configured."
|
||||
# )
|
||||
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# # recover the state
|
||||
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
|
||||
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
|
||||
|
||||
# # Convert bytes back to a UUID
|
||||
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
# oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
# result = r.get(r_key)
|
||||
# if not result:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
|
||||
# )
|
||||
|
||||
# try:
|
||||
# session = ConfluenceCloudOAuth.parse_session(result)
|
||||
|
||||
# # Exchange the authorization code for an access token
|
||||
# response = requests.post(
|
||||
# ConfluenceCloudOAuth.TOKEN_URL,
|
||||
# headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
# data={
|
||||
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
# "code": code,
|
||||
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
|
||||
# },
|
||||
# )
|
||||
|
||||
# response_data = response.json()
|
||||
|
||||
# if not response_data.get("ok"):
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
|
||||
# )
|
||||
|
||||
# # Extract token and team information
|
||||
# access_token: str = response_data.get("access_token")
|
||||
# team_id: str = response_data.get("team", {}).get("id")
|
||||
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
# credential_info = CredentialBase(
|
||||
# credential_json={"slack_bot_token": access_token},
|
||||
# admin_public=True,
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# name="Confluence OAuth",
|
||||
# )
|
||||
|
||||
# logger.info(f"Slack access token: {access_token}")
|
||||
|
||||
# credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# logger.info(f"new_credential_id={credential.id}")
|
||||
# except Exception as e:
|
||||
# return JSONResponse(
|
||||
# status_code=500,
|
||||
# content={
|
||||
# "success": False,
|
||||
# "message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
# },
|
||||
# )
|
||||
# finally:
|
||||
# r.delete(r_key)
|
||||
|
||||
# # return the result
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "success": True,
|
||||
# "message": "Slack OAuth completed successfully.",
|
||||
# "team_id": team_id,
|
||||
# "authed_user_id": authed_user_id,
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> JSONResponse:
    """Complete the Google Drive OAuth flow and store the credentials.

    Looks up the session stored in redis under the key derived from
    ``state``, exchanges ``code`` for tokens at ``GoogleDriveOAuth.TOKEN_URL``,
    refreshes them through ``get_google_oauth_creds`` and saves the sanitized
    result as a Google Drive credential via ``create_credential``.

    Raises:
        HTTPException: 500 if the Google Drive client id/secret is not
            configured; 400 if ``state`` is malformed or the stored session
            is not found.

    Any failure during the exchange is reported as a 500 JSON response with
    ``success=False``. Once the session has been found, the redis key is
    always deleted, whether or not the exchange succeeded.
    """
    if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Google Drive client ID or client secret is not configured.",
        )

    # Resolve the tenant explicitly, mirroring prepare_authorization_request,
    # so the state is read from the same redis client it was written with.
    r = get_redis_client(tenant_id=get_current_tenant_id())

    # recover the state: add the padding back (Base64 decoding requires it),
    # decode to bytes and convert the bytes back to a UUID
    padded_state = state + "=" * (-len(state) % 4)
    try:
        oauth_uuid = uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state))
    except ValueError as e:
        # binascii.Error (bad base64) is a ValueError; so is a wrong byte count
        raise HTTPException(
            status_code=400,
            detail="Google Drive OAuth failed - malformed OAuth state parameter",
        ) from e
    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    session: GoogleDriveOAuth.OAuthSession
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
                "grant_type": "authorization_code",
            },
            # requests never times out by default; don't hang the request
            # handler on an unresponsive token endpoint
            timeout=30,
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # Fail with a readable message instead of a bare KeyError text.
        refresh_token = authorization_response.get("refresh_token")
        if not refresh_token:
            raise RuntimeError(
                "Google token response did not include a refresh_token."
            )

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {
            "client_id": OAUTH_GOOGLE_DRIVE_CLIENT_ID,
            "client_secret": OAUTH_GOOGLE_DRIVE_CLIENT_SECRET,
            "refresh_token": refresh_token,
        }

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        # the state is single-use regardless of the outcome
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
91
backend/ee/onyx/server/oauth/api.py
Normal file
91
backend/ee/onyx/server/oauth/api.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Module-level logger, created once at import time via setup_logger().
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
    connector: DocumentSource,
    redirect_on_success: str | None,
    user: User = Depends(current_admin_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Build the provider authorization URL the frontend sends the browser to.

    A random UUID identifies the flow: its url-safe base64 form (padding
    stripped) is the OAuth ``state`` parameter, and its string form keys the
    session JSON stored in redis for 10 minutes. When DEV_MODE is set, the
    provider's dev URL generator is used instead of the regular one.

    Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/

    Raises:
        HTTPException: 404 when ``connector`` has no OAuth flow implemented;
            500 when no session payload was produced for the connector.
    """
    # random per-flow id: protects the flow and lets the callback find the
    # stored user data again
    flow_id = uuid.uuid4()
    redis_key = f"da_oauth:{flow_id}"

    # url-safe base64 of the raw uuid bytes, without the trailing padding
    oauth_state = base64.urlsafe_b64encode(flow_id.bytes).decode("utf-8").rstrip("=")

    oauth_url: str | None = None
    session: str | None = None

    if connector == DocumentSource.SLACK:
        oauth_url = (
            SlackOAuth.generate_dev_oauth_url(oauth_state)
            if DEV_MODE
            else SlackOAuth.generate_oauth_url(oauth_state)
        )
        session = SlackOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )
    elif connector == DocumentSource.CONFLUENCE:
        oauth_url = (
            ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
            if DEV_MODE
            else ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
        )
        session = ConfluenceCloudOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )
    elif connector == DocumentSource.GOOGLE_DRIVE:
        oauth_url = (
            GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
            if DEV_MODE
            else GoogleDriveOAuth.generate_oauth_url(oauth_state)
        )
        session = GoogleDriveOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )

    if not oauth_url:
        raise HTTPException(
            status_code=404,
            detail=f"The document source type {connector} does not have OAuth implemented",
        )

    if not session:
        raise HTTPException(
            status_code=500,
            detail=f"The document source type {connector} failed to generate an OAuth session.",
        )

    # keep the session state around for the redirect back;
    # 10 min is the max we want an oauth flow to be valid
    redis_client = get_redis_client(tenant_id=tenant_id)
    redis_client.set(redis_key, session, ex=600)

    return JSONResponse(content={"url": oauth_url})
|
||||
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Router instance defined in its own module so other modules can import it;
# every route registered on it is served under the "/oauth" prefix.
router: APIRouter = APIRouter(prefix="/oauth")
|
||||
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Module-level logger, created once at import time via setup_logger().
logger = setup_logger()
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
    """Constants, payload models and URL builders for the Confluence Cloud OAuth flow.

    Reference: https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    # Shape the token endpoint response is validated against.
    class TokenResponse(BaseModel):
        access_token: str
        expires_in: int
        token_type: str
        refresh_token: str
        scope: str

    # Shape of one entry in the accessible-resources listing.
    class AccessibleResources(BaseModel):
        id: str
        name: str
        url: str
        scopes: list[str]
        avatarUrl: str

    CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
    TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL

    ACCESSIBLE_RESOURCE_URL = "https://api.atlassian.com/oauth/token/accessible-resources"

    # All read scopes per
    # https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    # The scopes are joined with an already percent-encoded space ("%20").
    CONFLUENCE_OAUTH_SCOPE = "%20".join(
        (
            # classic scope
            "read:confluence-space.summary",
            "read:confluence-props",
            "read:confluence-content.all",
            "read:confluence-content.summary",
            "read:confluence-content.permission",
            "read:confluence-user",
            "read:confluence-groups",
            "readonly:content.attachment:confluence",
            "search:confluence",
            # granular scope
            "read:attachment:confluence",  # possibly unneeded unless calling v2 attachments api
            "offline_access",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Build the authorization URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the Atlassian authorization URL for the given redirect and state.

        See https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
        """
        query_parts = [
            "audience=api.atlassian.com",
            f"client_id={cls.CLIENT_ID}",
            f"scope={cls.CONFLUENCE_OAUTH_SCOPE}",
            f"redirect_uri={redirect_uri}",
            f"state={state}",
            "response_type=code",
            "prompt=consent",
        ]
        return "https://auth.atlassian.com/authorize?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        return ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Inverse of session_dump_json: validate the JSON string into an OAuthSession."""
        return ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)

    @classmethod
    def generate_finalize_url(cls, credential_id: int) -> str:
        """Return the frontend finalize page URL for the given credential id."""
        finalize_page = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize"
        return f"{finalize_page}?credential={credential_id}"
|
||||
|
||||
|
||||
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Handles the backend logic for the frontend page that the user is redirected to
    after visiting the oauth authorization url.

    Args:
        code: authorization code to exchange at the token endpoint.
        state: unpadded urlsafe-base64 encoding of the 16 UUID bytes that key the
            OAuth session stored in redis.
        user: the admin user the new credential is created for.
        db_session: database session used to create the credential.
        tenant_id: tenant whose redis client holds the OAuth session.

    Returns:
        A JSON body with ``success``/``message``; on success it also carries
        ``finalize_url`` and ``redirect_on_success``. Failures after the session
        lookup are reported as a 500 JSON body.

    Raises:
        HTTPException: 500 if the client id/secret is not configured, 400 if
            ``state`` is malformed or no matching session exists in redis.
    """

    if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Confluence Cloud client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    # `state` is caller-controlled: a bad value must produce a 400, not an
    # unhandled decoding error.
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes, then to a UUID.
        # binascii.Error is a ValueError subclass; UUID() raises ValueError
        # when it is not given exactly 16 bytes.
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Confluence Cloud OAuth failed - malformed OAuth state.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)
    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = ConfluenceCloudOAuth.parse_session(session_json)

        if not DEV_MODE:
            redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
        else:
            redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            ConfluenceCloudOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": ConfluenceCloudOAuth.CLIENT_ID,
                "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
            # seconds; without a timeout a stalled token endpoint would block
            # this request indefinitely
            timeout=30,
        )

        # An error response does not match TokenResponse, so validation doubles
        # as the success check for the exchange.
        try:
            token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
                response.text
            )
        except ValidationError as e:
            raise RuntimeError(
                "Confluence Cloud OAuth failed during code/token exchange."
            ) from e

        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=token_response.expires_in)

        credential_info = CredentialBase(
            credential_json={
                "confluence_access_token": token_response.access_token,
                "confluence_refresh_token": token_response.refresh_token,
                "created_at": now.isoformat(),
                "expires_at": expires_at.isoformat(),
                "expires_in": token_response.expires_in,
                "scope": token_response.scope,
            },
            admin_public=True,
            source=DocumentSource.CONFLUENCE,
            name="Confluence Cloud OAuth",
        )

        credential = create_credential(credential_info, user, db_session)
    except Exception as e:
        # Keep a server-side record; the client only receives the message below.
        logger.exception("Confluence Cloud OAuth callback failed.")
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
            },
        )
    finally:
        # The stored session is single-use: drop it whether or not the flow succeeded.
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Confluence Cloud OAuth completed successfully.",
            "finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
|
||||
|
||||
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
    credential_id: int,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Atlassian's API is weird and does not supply us with enough info to be in a
    usable state after authorizing. All API's require a cloud id. We have to list
    the accessible resources/sites and let the user choose which site to use.

    Args:
        credential_id: id of the credential whose stored access token is used.
        user: the admin user the credential is looked up for.
        db_session: database session used for the lookup.
        tenant_id: unused by this handler.

    Returns:
        A JSON body with ``success``/``message`` and, on success, the list of
        ``accessible_resources``. Request or parsing failures are reported as a
        500 JSON body.

    Raises:
        HTTPException: 400 if the credential is not found or holds no
            Confluence Cloud access token.
    """

    credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
    if not credential:
        raise HTTPException(400, f"Credential {credential_id} not found.")

    # A credential without a Confluence OAuth token is a caller error (400),
    # not an unhandled KeyError.
    credential_dict = credential.credential_json or {}
    access_token = credential_dict.get("confluence_access_token")
    if not access_token:
        raise HTTPException(
            400,
            f"Credential {credential_id} does not contain a Confluence Cloud access token.",
        )

    try:
        # List the sites this access token is authorized for
        response = requests.get(
            ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/json",
            },
            # seconds; without a timeout a stalled endpoint would block this
            # request indefinitely
            timeout=30,
        )

        response.raise_for_status()
        accessible_resources_data = response.json()

        # Validate the list of AccessibleResources
        try:
            accessible_resources = [
                ConfluenceCloudOAuth.AccessibleResources(**resource)
                for resource in accessible_resources_data
            ]
        except ValidationError as e:
            raise RuntimeError(f"Failed to parse accessible resources: {e}") from e
    except Exception as e:
        # Keep a server-side record; the client only receives the message below.
        logger.exception("Failed to retrieve Confluence Cloud accessible resources.")
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
            },
        )

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Confluence Cloud get accessible resources completed successfully.",
            "accessible_resources": [
                resource.model_dump() for resource in accessible_resources
            ],
        }
    )
|
||||
|
||||
|
||||
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
    credential_id: int,
    cloud_id: str,
    cloud_name: str,
    cloud_url: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Saves the info for the selected cloud site to the credential.

    Final step of the confluence oauth flow: once the traditional OAuth process
    is done, the user picks a site to associate with the credentials, and that
    choice is written into the credential json as ``cloud_id``, ``cloud_name``
    and ``wiki_base``. After this, the credential is usable.

    Raises:
        HTTPException: 400 if the credential is not found for this user.
    """

    existing = fetch_credential_by_id_for_user(credential_id, user, db_session)
    if not existing:
        raise HTTPException(
            status_code=400,
            detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
        )

    # Work on a copy so the loaded credential's json is not mutated in place.
    merged_json: dict[str, Any] = dict(existing.credential_json)
    merged_json.update(
        cloud_id=cloud_id,
        cloud_name=cloud_name,
        wiki_base=cloud_url,
    )

    try:
        update_credential_json(credential_id, merged_json, user, db_session)
    except Exception as err:
        failure_body = {
            "success": False,
            "message": f"An error occurred during Confluence Cloud OAuth: {str(err)}",
        }
        return JSONResponse(status_code=500, content=failure_body)

    # return the result
    success_body = {
        "success": True,
        "message": "Confluence Cloud OAuth finalized successfully.",
        "redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
    }
    return JSONResponse(content=success_body)
|
||||
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
    """Constants, session model and URL builders for the Google Drive OAuth flow.

    References:
    - https://developers.google.com/identity/protocols/oauth2
    - https://developers.google.com/identity/protocols/oauth2/web-server
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.danswer.dev/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    # The scopes are joined with an already percent-encoded space ("%20").
    SCOPE = "%20".join(
        (
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive.metadata.readonly",
            "https://www.googleapis.com/auth/admin.directory.user.readonly",
            "https://www.googleapis.com/auth/admin.directory.group.readonly",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Build the authorization URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the Google authorization URL for the given redirect and state."""
        query_parts = [
            f"client_id={cls.CLIENT_ID}",
            f"redirect_uri={redirect_uri}",
            "response_type=code",
            f"scope={cls.SCOPE}",
            "access_type=offline",
            f"state={state}",
            # without prompt=consent, a refresh token is only issued the first
            # time the user approves
            "prompt=consent",
        ]
        return "https://accounts.google.com/o/oauth2/v2/auth?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        return GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Inverse of session_dump_json: validate the JSON string into an OAuthSession."""
        return GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Complete the Google Drive OAuth flow and store the resulting credential.

    Args:
        code: authorization code to exchange at the token endpoint.
        state: unpadded urlsafe-base64 encoding of the 16 UUID bytes that key the
            OAuth session stored in redis.
        user: the admin user the new credential is created for.
        db_session: database session used to create the credential.
        tenant_id: tenant whose redis client holds the OAuth session.

    Returns:
        A JSON body with ``success``/``message``; on success it also carries
        ``finalize_url`` (always None here) and ``redirect_on_success``. Failures
        after the session lookup are reported as a 500 JSON body.

    Raises:
        HTTPException: 500 if the client id/secret is not configured, 400 if
            ``state`` is malformed or no matching session exists in redis.
    """
    if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Google Drive client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    # `state` is caller-controlled: a bad value must produce a 400, not an
    # unhandled decoding error.
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes, then to a UUID.
        # binascii.Error is a ValueError subclass; UUID() raises ValueError
        # when it is not given exactly 16 bytes.
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Google Drive OAuth failed - malformed OAuth state.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)
    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        if not DEV_MODE:
            redirect_uri = GoogleDriveOAuth.REDIRECT_URI
        else:
            redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
            # seconds; without a timeout a stalled token endpoint would block
            # this request indefinitely
            timeout=30,
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # Fail with a readable message instead of a bare KeyError text.
        refresh_token = authorization_response.get("refresh_token")
        if not refresh_token:
            raise RuntimeError(
                "Google Drive OAuth token response did not include a refresh token."
            )

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {
            "client_id": OAUTH_GOOGLE_DRIVE_CLIENT_ID,
            "client_secret": OAUTH_GOOGLE_DRIVE_CLIENT_SECRET,
            "refresh_token": refresh_token,
        }

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        # The stored session is single-use: drop it whether or not the flow succeeded.
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "finalize_url": None,
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
197
backend/ee/onyx/server/oauth/slack.py
Normal file
197
backend/ee/onyx/server/oauth/slack.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import base64
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class SlackOAuth:
    """Constants, session model and URL builders for the Slack OAuth flow.

    References:
    - https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
    - Example: https://api.slack.com/authentication/oauth-v2#exchanging
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_SLACK_CLIENT_ID
    CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

    TOKEN_URL = "https://slack.com/api/oauth.v2.access"

    # SCOPE is per https://docs.danswer.dev/connectors/slack
    BOT_SCOPE = ",".join(
        (
            "channels:history",
            "channels:read",
            "groups:history",
            "groups:read",
            "channels:join",
            "im:history",
            "users:read",
            "users:read.email",
            "usergroups:read",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Build the authorization URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the Slack authorization URL for the given redirect and state."""
        query_parts = [
            f"client_id={cls.CLIENT_ID}",
            f"redirect_uri={redirect_uri}",
            f"scope={cls.BOT_SCOPE}",
            f"state={state}",
        ]
        return "https://slack.com/oauth/v2/authorize?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        return SlackOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Inverse of session_dump_json: validate the JSON string into an OAuthSession."""
        return SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Complete the Slack OAuth flow and store the bot token as a credential.

    Args:
        code: authorization code to exchange at the token endpoint.
        state: unpadded urlsafe-base64 encoding of the 16 UUID bytes that key the
            OAuth session stored in redis.
        user: the admin user the new credential is created for.
        db_session: database session used to create the credential.
        tenant_id: tenant whose redis client holds the OAuth session.

    Returns:
        A JSON body with ``success``/``message``; on success it also carries
        ``finalize_url`` (always None here), ``redirect_on_success``,
        ``team_id`` and ``authed_user_id``. Unexpected failures after the
        session lookup are reported as a 500 JSON body.

    Raises:
        HTTPException: 500 if the client id/secret is not configured; 400 if
            ``state`` is malformed, no matching session exists in redis, or
            Slack reports the exchange as not ok.
    """
    if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Slack client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    # `state` is caller-controlled: a bad value must produce a 400, not an
    # unhandled decoding error.
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes, then to a UUID.
        # binascii.Error is a ValueError subclass; UUID() raises ValueError
        # when it is not given exactly 16 bytes.
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Slack OAuth failed - malformed OAuth state.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)
    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = SlackOAuth.parse_session(session_json)

        if not DEV_MODE:
            redirect_uri = SlackOAuth.REDIRECT_URI
        else:
            redirect_uri = SlackOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            SlackOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": SlackOAuth.CLIENT_ID,
                "client_secret": SlackOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
            },
            # seconds; without a timeout a stalled token endpoint would block
            # this request indefinitely
            timeout=30,
        )

        response_data = response.json()

        if not response_data.get("ok"):
            raise HTTPException(
                status_code=400,
                detail=f"Slack OAuth failed: {response_data.get('error')}",
            )

        # Extract token and team information
        access_token: str = response_data.get("access_token")
        team_id: str = response_data.get("team", {}).get("id")
        authed_user_id: str = response_data.get("authed_user", {}).get("id")

        credential_info = CredentialBase(
            credential_json={"slack_bot_token": access_token},
            admin_public=True,
            source=DocumentSource.SLACK,
            name="Slack OAuth",
        )

        create_credential(credential_info, user, db_session)
    except HTTPException:
        # Let the deliberate 400 above reach the client; the generic handler
        # below would otherwise turn it into a 500.
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Slack OAuth: {str(e)}",
            },
        )
    finally:
        # The stored session is single-use: drop it whether or not the flow succeeded.
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Slack OAuth completed successfully.",
            "finalize_url": None,
            "redirect_on_success": session.redirect_on_success,
            "team_id": team_id,
            "authed_user_id": authed_user_id,
        }
    )
|
||||
@@ -2,6 +2,7 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -122,6 +138,7 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
@@ -141,6 +158,12 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -157,11 +180,16 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
items=minimal_chat_sessions,
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -172,6 +200,12 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -193,6 +227,9 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -203,6 +240,12 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -213,6 +256,9 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
def fetch_billing_information(
|
||||
tenant_id: str,
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the response indicates no subscription
|
||||
if (
|
||||
isinstance(response_data, dict)
|
||||
and "subscribed" in response_data
|
||||
and not response_data["subscribed"]
|
||||
):
|
||||
return SubscriptionStatusResponse(**response_data)
|
||||
|
||||
# Otherwise, parse as BillingInformation
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
|
||||
@@ -104,14 +104,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
logger.info(f"Provisioning tenant: {tenant_id}")
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
@@ -200,25 +200,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
@@ -227,6 +208,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@@ -238,6 +220,26 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
|
||||
@@ -6,7 +6,7 @@ MODEL_WARM_UP_STRING = "hi " * 512
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-005"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
|
||||
@@ -5,6 +5,7 @@ from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
@@ -28,11 +29,13 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.utils import pass_aws_key
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
@@ -78,7 +81,7 @@ class CloudEmbedding:
|
||||
self._closed = False
|
||||
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
self, texts: list[str], model: str | None, reduced_dimension: int | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
@@ -91,7 +94,11 @@ class CloudEmbedding:
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(input=text_batch, model=model)
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
@@ -178,17 +185,24 @@ class CloudEmbedding:
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
client = TextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
embeddings = await client.get_embeddings_async(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # This is the default
|
||||
)
|
||||
return [embedding.values for embedding in embeddings]
|
||||
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
|
||||
|
||||
# Split into batches of 25 texts
|
||||
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
batches = [
|
||||
inputs[i : i + max_texts_per_batch]
|
||||
for i in range(0, len(inputs), max_texts_per_batch)
|
||||
]
|
||||
|
||||
# Dispatch all embedding calls asynchronously at once
|
||||
tasks = [
|
||||
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete in parallel
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [embedding.values for batch in results for embedding in batch]
|
||||
|
||||
async def _embed_litellm_proxy(
|
||||
self, texts: list[str], model_name: str | None
|
||||
@@ -223,9 +237,10 @@ class CloudEmbedding:
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name)
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
@@ -326,6 +341,7 @@ async def embed_text(
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
reduced_dimension: int | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
@@ -369,6 +385,7 @@ async def embed_text(
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
reduced_dimension=reduced_dimension,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
@@ -440,7 +457,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
)
|
||||
|
||||
|
||||
async def cohere_rerank(
|
||||
async def cohere_rerank_api(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereAsyncClient(api_key=api_key)
|
||||
@@ -450,6 +467,45 @@ async def cohere_rerank(
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
async def cohere_rerank_aws(
    query: str,
    docs: list[str],
    model_name: str,
    region_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
) -> list[float]:
    """Rerank documents against a query via a model on AWS Bedrock.

    Args:
        query: The query text sent as the "query" field of the request.
        docs: The documents sent as the "documents" field of the request.
        model_name: Passed to Bedrock as the ``modelId`` to invoke.
        region_name: AWS region used for the "bedrock-runtime" client.
        aws_access_key_id: Access key id used to build the aioboto3 session.
        aws_secret_access_key: Secret key used to build the aioboto3 session.

    Returns:
        The "relevance_score" of each entry in the response's "results"
        list, ordered by each entry's "index" field. An empty list is
        returned when the response has no "results" key.
    """
    aws_session = aioboto3.Session(
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    async with aws_session.client(
        "bedrock-runtime", region_name=region_name
    ) as bedrock_client:
        request_payload = {
            "query": query,
            "documents": docs,
            "api_version": 2,
        }

        # Send the rerank request to the Bedrock runtime
        raw_response = await bedrock_client.invoke_model(
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
            body=json.dumps(request_payload),
        )

        # The body is a stream that must be awaited before it can be parsed
        raw_body = await raw_response["body"].read()
        parsed = json.loads(raw_body)

        # Restore the order given by each entry's "index" before
        # extracting the scores
        ranked_entries = sorted(
            parsed.get("results", []), key=lambda entry: entry["index"]
        )
        return [entry["relevance_score"] for entry in ranked_entries]
|
||||
|
||||
|
||||
async def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
@@ -508,6 +564,7 @@ async def process_embed_request(
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
reduced_dimension=embed_request.reduced_dimension,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
@@ -564,15 +621,32 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = await cohere_rerank(
|
||||
sim_scores = await cohere_rerank_api(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Bedrock Rerank Requires an API Key")
|
||||
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
|
||||
rerank_request.api_key
|
||||
)
|
||||
sim_scores = await cohere_rerank_aws(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
region_name=aws_region,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -70,3 +70,32 @@ def get_gpu_type() -> str:
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
|
||||
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
    """Parse an AWS API key string into its components.

    Args:
        api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'

    Returns:
        Tuple of (access_key, secret_key, region)

    Raises:
        ValueError: If the key does not start with 'aws', or does not split
            into exactly four underscore-separated parts.
    """
    if not api_key.startswith("aws"):
        raise ValueError("API key must start with 'aws' prefix")

    parts = api_key.split("_")
    if len(parts) != 4:
        # The two literals are joined by implicit concatenation, so the
        # first one must end with an explicit separator.
        raise ValueError(
            f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts; "
            "this is an onyx specific format for formatting the aws secrets for bedrock"
        )

    # The length check above guarantees exactly four parts, so this
    # unpacking cannot fail; the leading prefix component is discarded.
    _, aws_access_key_id, aws_secret_access_key, aws_region = parts
    return aws_access_key_id, aws_secret_access_key, aws_region
|
||||
|
||||
@@ -98,8 +98,16 @@ def choose_tool(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=[tool.tool_definition() for tool in tools] or None,
|
||||
tool_choice=("required" if tools and force_use_tool.force_use else None),
|
||||
tools=(
|
||||
[tool.tool_definition() for tool in tools] or None
|
||||
if using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
tool_choice=(
|
||||
"required"
|
||||
if tools and force_use_tool.force_use and using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
|
||||
@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User
|
||||
user: User | None = None
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
@@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
user = await self.user_db.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
# Make sure user is not None before adding OAuth account
|
||||
if user is not None:
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||
# but adding as a safeguard
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
@@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
# Add OAuth account only if user creation was successful
|
||||
if user is not None:
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create user account"
|
||||
)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
# User exists, update OAuth account if needed
|
||||
if user is not None: # Add explicit check
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# Ensure user is not None before proceeding
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to authenticate or create user"
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
@@ -508,6 +523,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if user_count == 1:
|
||||
@@ -529,7 +545,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
logger.debug(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
data={"action": "create"},
|
||||
|
||||
@@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(cc_pair)
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
|
||||
@@ -23,9 +23,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.indexing.utils import _should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
@@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
old_search_settings = check_and_perform_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
@@ -439,6 +439,15 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
for search_settings_instance in search_settings_list:
|
||||
# skip non-live search settings that don't have background reindex enabled
|
||||
# those should just auto-change to live shortly after creation without
|
||||
# requiring any indexing till that point
|
||||
if (
|
||||
not search_settings_instance.status.is_current()
|
||||
and not search_settings_instance.background_reindex_enabled
|
||||
):
|
||||
continue
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
@@ -456,23 +465,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
if not should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if search_settings_instance.status.is_current():
|
||||
# the indexing trigger is only checked and cleared with the current search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
||||
@@ -346,11 +346,10 @@ def validate_indexing_fences(
|
||||
return
|
||||
|
||||
|
||||
def _should_index(
|
||||
def should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@@ -415,9 +414,9 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if search_settings_instance.status.is_current():
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
# if a manual indexing trigger is on the cc pair, honor it for live search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
|
||||
@@ -298,6 +298,7 @@ def cloud_beat_task_generator(
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -325,6 +326,8 @@ def cloud_beat_task_generator(
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -344,6 +347,7 @@ def cloud_beat_task_generator(
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -11,10 +11,27 @@ def emit_background_error(
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
error_message = ""
|
||||
|
||||
# try to write to the db, but handle IntegrityError specifically
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = (
|
||||
f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not error_message:
|
||||
return
|
||||
|
||||
# if we get here from an IntegrityError, try to write the error message to the db
|
||||
# we need a new session because the first session is now invalid
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, error_message, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Optional
|
||||
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -92,11 +93,17 @@ def _get_connector_runner(
|
||||
if not INTEGRATION_TESTS_MODE:
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except UnexpectedValidationError as e:
|
||||
logger.exception(
|
||||
"Unable to instantiate connector due to an unexpected temporary issue."
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
logger.exception("Unable to instantiate connector. Pausing until fixed.")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed. Sometimes there are cases where the connector will
|
||||
# it will never succeed
|
||||
|
||||
# Sometimes there are cases where the connector will
|
||||
# intermittently fail to initialize in which case we should pass in
|
||||
# leave_connector_active=True to allow it to continue.
|
||||
# For example, if there is nightly maintenance on a Confluence Server instance,
|
||||
|
||||
@@ -756,6 +756,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
|
||||
answer = Answer(
|
||||
prompt_builder=prompt_builder,
|
||||
is_connected=is_connected,
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
@@ -636,3 +640,6 @@ TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
MOCK_LLM_RESPONSE = (
|
||||
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
||||
|
||||
@@ -213,6 +213,12 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
    """Modes selectable for the query history feature.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    # Query history is turned off
    DISABLED = "disabled"
    # Query history is available with user identity anonymized
    ANONYMIZED = "anonymized"
    # Query history is available unchanged
    NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
38
backend/onyx/configs/llm_configs.py
Normal file
38
backend/onyx/configs/llm_configs.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
|
||||
from onyx.server.settings.store import load_settings
|
||||
|
||||
|
||||
def get_image_extraction_and_analysis_enabled() -> bool:
    """Return whether image extraction and analysis is enabled.

    The value is read from the workspace settings. If the settings cannot be
    loaded, or the setting is unset (None), the result is False.

    Returns:
        The workspace setting when it is set, otherwise False.
    """
    # Function-scope import so this block does not depend on a module-level logger.
    import logging

    try:
        enabled = load_settings().image_extraction_and_analysis_enabled
    except Exception:
        # Best-effort lookup: a settings failure must not break callers, but it
        # should leave a trace instead of silently disabling the feature.
        logging.getLogger(__name__).warning(
            "Failed to load image_extraction_and_analysis_enabled from "
            "workspace settings; falling back to False",
            exc_info=True,
        )
        return False

    return enabled if enabled is not None else False
|
||||
|
||||
|
||||
def get_search_time_image_analysis_enabled() -> bool:
    """Return whether search-time image analysis is enabled.

    The value is read from the workspace settings. If the settings cannot be
    loaded, or the setting is unset (None), the result is False.

    Returns:
        The workspace setting when it is set, otherwise False.
    """
    # Function-scope import so this block does not depend on a module-level logger.
    import logging

    try:
        enabled = load_settings().search_time_image_analysis_enabled
    except Exception:
        # Best-effort lookup: a settings failure must not break callers, but it
        # should leave a trace instead of silently disabling the feature.
        logging.getLogger(__name__).warning(
            "Failed to load search_time_image_analysis_enabled from "
            "workspace settings; falling back to False",
            exc_info=True,
        )
        return False

    return enabled if enabled is not None else False
|
||||
|
||||
|
||||
def get_image_analysis_max_size_mb() -> int:
    """Return the maximum image size, in MB, allowed for image analysis.

    The value is read from the workspace settings. If the settings cannot be
    loaded, or the setting is unset (None), the constant
    DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB is returned instead.

    Returns:
        The workspace setting when it is set, otherwise the default constant.
    """
    # Function-scope import so this block does not depend on a module-level logger.
    import logging

    try:
        max_size_mb = load_settings().image_analysis_max_size_mb
    except Exception:
        # Best-effort lookup: a settings failure must not break callers, but it
        # should leave a trace instead of silently reverting to the default.
        logging.getLogger(__name__).warning(
            "Failed to load image_analysis_max_size_mb from workspace "
            "settings; falling back to the default",
            exc_info=True,
        )
        return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB

    if max_size_mb is None:
        return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
    return max_size_mb
|
||||
@@ -200,7 +200,6 @@ class AirtableConnector(LoadConnector):
|
||||
return attachment_response.content
|
||||
|
||||
logger.error(f"Failed to refresh attachment for {filename}")
|
||||
|
||||
raise
|
||||
|
||||
attachment_content = get_attachment_with_retry(url, record_id)
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -310,7 +310,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
# Catch-all for anything not captured by the above
|
||||
# Since we are unsure of the error and it may not disable the connector,
|
||||
# raise an unexpected error (does not disable connector)
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during blob storage settings validation: {e}"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,17 +11,19 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import attachment_to_content
|
||||
from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import convert_attachment_to_content
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.utils import process_attachment
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -33,28 +35,26 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Potential Improvements
|
||||
# 1. Include attachments, etc
|
||||
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
|
||||
# 1. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
|
||||
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
|
||||
_PAGE_EXPANSION_FIELDS = [
|
||||
"body.storage.value",
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
]
|
||||
_ATTACHMENT_EXPANSION_FIELDS = [
|
||||
"version",
|
||||
"space",
|
||||
"metadata.labels",
|
||||
]
|
||||
|
||||
_RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"restrictions.read.restrictions.user",
|
||||
@@ -83,7 +83,13 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ConfluenceConnector(
|
||||
LoadConnector,
|
||||
PollConnector,
|
||||
SlimConnector,
|
||||
CredentialsConnector,
|
||||
VisionEnabledConnector,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
@@ -100,14 +106,24 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
self.is_cloud = is_cloud
|
||||
self.space = space
|
||||
self.page_id = page_id
|
||||
self.index_recursively = index_recursively
|
||||
self.cql_query = cql_query
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
self._fetched_titles: set[str] = set()
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
||||
"""
|
||||
If nothing is provided, we default to fetching all pages
|
||||
Only one or none of the following options should be specified so
|
||||
@@ -137,6 +153,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
|
||||
self.probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
self.final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@@ -144,15 +171,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
# raises exception if there's a problem
|
||||
confluence_client = OnyxConfluence(
|
||||
self.is_cloud, self.wiki_base, credentials_provider
|
||||
)
|
||||
return None
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._confluence_client = confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
@@ -160,7 +194,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
|
||||
# Add time filters
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
@@ -172,7 +205,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
page_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
@@ -183,11 +215,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
comment_string = ""
|
||||
|
||||
comment_cql = f"type=comment and container='{page_id}'"
|
||||
comment_cql += self.cql_label_filter
|
||||
|
||||
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
|
||||
|
||||
for comment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=comment_cql,
|
||||
expand=expand,
|
||||
@@ -198,116 +229,177 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_object=comment,
|
||||
fetched_titles=set(),
|
||||
)
|
||||
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
) -> Document | None:
|
||||
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
Converts a Confluence page to a Document object.
|
||||
Includes the page content, comments, and attachments.
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
try:
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
|
||||
|
||||
object_text = None
|
||||
# Extract text from page
|
||||
if confluence_object["type"] == "page":
|
||||
object_text = extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=confluence_object,
|
||||
fetched_titles={confluence_object.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
self.confluence_client, page, self._fetched_titles
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parseable
|
||||
# Create the main section for the page content
|
||||
sections = [Section(text=page_content, link=page_url)]
|
||||
|
||||
# Process comments if available
|
||||
comment_text = self._get_comment_string_for_page_id(page_id)
|
||||
if comment_text:
|
||||
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
|
||||
|
||||
# Process attachments
|
||||
if "children" in page and "attachment" in page["children"]:
|
||||
attachments = self.confluence_client.get_attachments_for_page(
|
||||
page_id, expand="metadata"
|
||||
)
|
||||
|
||||
for attachment in attachments.get("results", []):
|
||||
# Process each attachment
|
||||
result = process_attachment(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_title,
|
||||
self.image_analysis_llm,
|
||||
)
|
||||
|
||||
if result.text:
|
||||
# Create a section for the attachment text
|
||||
attachment_section = Section(
|
||||
text=result.text,
|
||||
link=f"{page_url}#attachment-{attachment['id']}",
|
||||
image_file_name=result.file_name,
|
||||
)
|
||||
sections.append(attachment_section)
|
||||
elif result.error:
|
||||
logger.warning(
|
||||
f"Error processing attachment '{attachment.get('title')}': {result.error}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
metadata = {}
|
||||
if "space" in page:
|
||||
metadata["space"] = page["space"].get("name", "")
|
||||
|
||||
# Extract labels
|
||||
labels = []
|
||||
if "metadata" in page and "labels" in page["metadata"]:
|
||||
for label in page["metadata"]["labels"].get("results", []):
|
||||
labels.append(label.get("name", ""))
|
||||
if labels:
|
||||
metadata["labels"] = labels
|
||||
|
||||
# Extract owners
|
||||
primary_owners = []
|
||||
if "version" in page and "by" in page["version"]:
|
||||
author = page["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
primary_owners.append(BasicExpertInfo(display_name=display_name))
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime_from_string(page["version"]["when"]),
|
||||
primary_owners=primary_owners if primary_owners else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
|
||||
if not self.continue_on_failure:
|
||||
raise
|
||||
return None
|
||||
|
||||
# Get space name
|
||||
doc_metadata: dict[str, str | list[str]] = {
|
||||
"Wiki Space Name": confluence_object["space"]["name"]
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = (
|
||||
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
|
||||
)
|
||||
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
version_dict = confluence_object.get("version", {})
|
||||
last_modified = (
|
||||
datetime_from_string(version_dict.get("when"))
|
||||
if version_dict.get("when")
|
||||
else None
|
||||
)
|
||||
author_email = version_dict.get("by", {}).get("email")
|
||||
|
||||
title = confluence_object.get("title", "Untitled Document")
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
),
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
def _fetch_document_batches(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Yields batches of Documents. For each page:
|
||||
- Create a Document with 1 Section for the page text/comments
|
||||
- Then fetch attachments. For each attachment:
|
||||
- Attempt to convert it with convert_attachment_to_content(...)
|
||||
- If successful, create a new Section with the extracted text or summary.
|
||||
"""
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self._construct_page_query(start, end)
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
# Build doc from page
|
||||
doc = self._convert_page_to_document(page)
|
||||
if not doc:
|
||||
continue
|
||||
|
||||
# Now get attachments for that page:
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
# We'll use the page's XML to provide context if we summarize an image
|
||||
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
attachment["metadata"].get("mediaType", "")
|
||||
if not validate_attachment_filetype(
|
||||
attachment, self.image_analysis_llm
|
||||
):
|
||||
continue
|
||||
|
||||
# Attempt to get textual content or image summarization:
|
||||
try:
|
||||
logger.info(f"Processing attachment: {attachment['title']}")
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_context=confluence_xml,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
content_text, file_storage_name = response
|
||||
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
if content_text:
|
||||
doc.sections.append(
|
||||
Section(
|
||||
text=content_text,
|
||||
link=object_url,
|
||||
image_file_name=file_storage_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
exc_info=e,
|
||||
)
|
||||
if not self.continue_on_failure:
|
||||
raise
|
||||
|
||||
doc_batch.append(doc)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -328,55 +420,63 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""
|
||||
Return 'slim' docs (IDs + minimal permission data).
|
||||
Does not fetch actual text. Used primarily for incremental permission sync.
|
||||
"""
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
# Query pages
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
"ancestors": page_ancestors or [],
|
||||
"ancestors": page_ancestors,
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
),
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
|
||||
# Query attachments for each page
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
if not validate_attachment_filetype(attachment):
|
||||
# If you skip images, you'll skip them in the permission sync
|
||||
attachment["metadata"].get("mediaType", "")
|
||||
if not validate_attachment_filetype(
|
||||
attachment, self.image_analysis_llm
|
||||
):
|
||||
continue
|
||||
attachment_restrictions = attachment.get("restrictions")
|
||||
|
||||
attachment_restrictions = attachment.get("restrictions", {})
|
||||
if not attachment_restrictions:
|
||||
attachment_restrictions = page_restrictions
|
||||
attachment_restrictions = page_restrictions or {}
|
||||
|
||||
attachment_space_key = attachment.get("space", {}).get("key")
|
||||
if not attachment_space_key:
|
||||
attachment_space_key = page_space_key
|
||||
|
||||
attachment_perm_sync_data = {
|
||||
"restrictions": attachment_restrictions or {},
|
||||
"restrictions": attachment_restrictions,
|
||||
"space_key": attachment_space_key,
|
||||
}
|
||||
|
||||
@@ -390,16 +490,16 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=attachment_perm_sync_data,
|
||||
)
|
||||
)
|
||||
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
if callback and callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"retrieve_all_slim_documents: Stop signal detected"
|
||||
)
|
||||
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
|
||||
yield doc_metadata_list
|
||||
@@ -420,11 +520,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise InsufficientPermissionsError(
|
||||
"Insufficient permissions to access Confluence resources (HTTP 403)."
|
||||
)
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Confluence error (status={status_code}): {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,19 +1,37 @@
|
||||
import math
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -22,12 +40,14 @@ logger = setup_logger()
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -43,124 +63,355 @@ class ConfluenceUser(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
    """Decide how long to wait before retrying a rate-limited Confluence call.

    Re-raises `e` when it has no response/headers, or when it is neither an
    HTTP 429 nor carries the rate-limit message in its body. Otherwise the
    wait is taken from the `Retry-After` header (clamped to
    [MIN_DELAY, MAX_DELAY]) or, when the header is missing or not an integer,
    from an exponential backoff based on `attempt`.

    Args:
        e: The HTTPError raised by the Confluence call.
        attempt: Zero-based retry attempt, used as the backoff exponent.

    Returns:
        The `time.monotonic()` value (rounded up to an int) at which the
        caller may retry. This is a deadline, not a duration.
    """
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    response = e.response

    # Guard against a missing response or headers to avoid an AttributeError.
    if response is None or response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    # Only rate-limit failures are retried; everything else propagates.
    is_rate_limited = (
        response.status_code == 429
        or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
    )
    if not is_rate_limited:
        raise e

    retry_after = None
    header_value = response.headers.get("Retry-After")
    if header_value is not None:
        try:
            retry_after = int(header_value)
        except ValueError:
            # Non-integer header value: fall through to exponential backoff.
            retry_after = None
        else:
            if retry_after > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
                )
                retry_after = MAX_DELAY
            if retry_after < MIN_DELAY:
                retry_after = MIN_DELAY

    if retry_after is None:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    else:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
        )
        delay = retry_after

    return math.ceil(time.monotonic() + delay)
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence call with retries for rate limits and flaky errors.

    The wrapper makes up to MAX_RETRIES attempts:

    - On an HTTPError, `_handle_http_error` either re-raises it (not a
      rate-limit error) or returns the monotonic deadline to wait for before
      the next attempt.
    - On an AttributeError (raised intermittently from inside the Confluence
      library), it waits 5 seconds and retries.

    The error from the final attempt is always re-raised, so the wrapper never
    returns None merely because the retries ran out.

    Args:
        confluence_call: The callable to wrap.

    Returns:
        A callable with the same signature as `confluence_call`.

    Raises:
        TimeoutError: If an attempt would start more than TIMEOUT seconds
            after the first one.
        HTTPError: For non-rate-limit errors, or a rate-limit error on the
            final attempt.
        AttributeError: If the final attempt still raises it.
    """
    # Function-scope import so this block does not rely on a file-level import.
    from functools import wraps

    MAX_RETRIES = 5
    TIMEOUT = 600

    @wraps(confluence_call)
    def wrapped_call(*args: Any, **kwargs: Any) -> Any:
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            is_last_attempt = attempt == MAX_RETRIES - 1

            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                if is_last_attempt:
                    # No retries left: surface the error instead of sleeping
                    # and then silently returning None.
                    raise

                delay_until = _handle_http_error(e, attempt)
                # delay_until is an absolute monotonic deadline; log the
                # remaining wait, not the deadline itself.
                remaining = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {remaining:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if is_last_attempt:
                    raise

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
class OnyxConfluence:
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is a custom Confluence class that:
|
||||
|
||||
A. overrides the default Confluence class to add a custom CQL method.
|
||||
B.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
CREDENTIAL_PREFIX = "connector:confluence:credential"
|
||||
CREDENTIAL_TTL = 300 # 5 min
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
is_cloud: bool,
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
self.redis_client = get_redis_client(
|
||||
tenant_id=credentials_provider.get_tenant_id()
|
||||
)
|
||||
else:
|
||||
self.static_credentials = self._credentials_provider.get_credentials()
|
||||
|
||||
self._confluence = Confluence(url)
|
||||
self.credential_key: str = (
|
||||
self.CREDENTIAL_PREFIX
|
||||
+ f":credential_{self._credentials_provider.get_provider_key()}"
|
||||
)
|
||||
|
||||
self._kwargs: Any = None
|
||||
|
||||
self.shared_base_kwargs = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
    """Return the credentials to use now, refreshing OAuth tokens when due.

    Static credentials are returned as-is. Dynamic credentials are read from
    redis first and from the credentials provider otherwise; when more than
    half of their lifetime has elapsed they are refreshed and written back to
    both redis and the provider.

    This method is intended to be used within a distributed lock.
    Lock, call this, update credentials if the tokens were refreshed, then
    release.

    Returns:
        A tuple of
        1. The up to date credentials
        2. True if the credentials were updated

    Raises:
        RuntimeError: If no redis client exists for dynamic credentials, or
            if the OAuth client id / secret are not configured.
    """
    # static credentials are preloaded, so no locking/redis required
    if self.static_credentials:
        return self.static_credentials, False

    if not self.redis_client:
        raise RuntimeError("self.redis_client is None")

    # dynamic credentials need locking
    # check redis first, then fallback to the DB
    credential_raw = self.redis_client.get(self.credential_key)
    if credential_raw is not None:
        credential_bytes = cast(bytes, credential_raw)
        credential_str = credential_bytes.decode("utf-8")
        credential_json: dict[str, Any] = json.loads(credential_str)
    else:
        credential_json = self._credentials_provider.get_credentials()

    if "confluence_refresh_token" not in credential_json:
        # static credentials ... cache them permanently and return
        self.static_credentials = credential_json
        return credential_json, False

    if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
        raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")

    if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
        raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")

    # check if we should refresh tokens. we're deciding to refresh halfway
    # to expiration
    now = datetime.now(timezone.utc)
    created_at = datetime.fromisoformat(credential_json["created_at"])
    expires_in: int = credential_json["expires_in"]
    renew_at = created_at + timedelta(seconds=expires_in // 2)
    if now <= renew_at:
        # cached/current credentials are reasonably up to date
        return credential_json, False

    # we need to refresh
    logger.info("Renewing Confluence Cloud credentials...")
    new_credentials = confluence_refresh_tokens(
        OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
        OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
        credential_json["cloud_id"],
        credential_json["confluence_refresh_token"],
    )

    # store the new credentials to redis and to the db thru the provider
    # redis: we use a 5 min TTL because we are given a 10 minute grace period
    # when keys are rotated. it's easier to expire the cached credentials
    # reasonably frequently rather than trying to handle strong synchronization
    # between the db and redis everywhere the credentials might be updated
    #
    # The write is unconditional on purpose: with nx=True redis keeps an
    # already-cached (now outdated) entry and drops the refreshed tokens, so
    # every reader would see stale credentials until that entry expires.
    new_credential_str = json.dumps(new_credentials)
    self.redis_client.set(
        self.credential_key, new_credential_str, ex=self.CREDENTIAL_TTL
    )
    self._credentials_provider.set_credentials(new_credentials)

    return new_credentials, True
|
||||
|
||||
@staticmethod
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
    """Build the ``oauth2`` argument for a ``Confluence`` client.

    Args:
        credentials: The credential dict. OAuth credentials are recognised
            by the presence of ``confluence_refresh_token``.

    Returns:
        A dict holding the client id and the access token, or an empty dict
        when the credentials carry no refresh token.
    """
    if "confluence_refresh_token" not in credentials:
        return {}

    return {
        "client_id": OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
        "token": {"access_token": credentials["confluence_access_token"]},
    }
|
||||
|
||||
def _probe_connection(
    self,
    **kwargs: Any,
) -> None:
    """Check that Confluence answers with the current credentials.

    A throwaway ``Confluence`` client is built from ``shared_base_kwargs``
    overridden by ``kwargs`` and is asked for a single space.

    Args:
        **kwargs: Extra keyword arguments for the ``Confluence`` constructor;
            they take precedence over ``shared_base_kwargs``.

    Raises:
        RuntimeError: If the space listing comes back empty.
    """
    client_kwargs = {**self.shared_base_kwargs, **kwargs}

    with self._credentials_provider:
        credentials, _ = self._renew_credentials()

        # probe with a directly constructed client rather than the
        # rate-limited wrapper
        auth_kwargs: dict[str, Any]
        if "confluence_refresh_token" in credentials:
            logger.info("Probing Confluence with OAuth Access Token.")

            auth_kwargs = {"oauth2": OnyxConfluence._make_oauth2_dict(credentials)}
            url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
        else:
            logger.info("Probing Confluence with Personal Access Token.")
            url = self._url
            if self._is_cloud:
                auth_kwargs = {
                    "username": credentials["confluence_username"],
                    "password": credentials["confluence_access_token"],
                }
            else:
                auth_kwargs = {"token": credentials["confluence_access_token"]}

        probe_client = Confluence(url=url, **auth_kwargs, **client_kwargs)
        spaces = probe_client.get_all_spaces(limit=1)

        # uncomment the following for testing
        # the following is an attempt to retrieve the user's timezone
        # Unfortunately, all data is returned in UTC regardless of the user's
        # time zone even though CQL parses incoming times based on the user's
        # time zone
        # space_key = spaces["results"][0]["key"]
        # space_details = probe_client.cql(f"space.key={space_key}+AND+type=space")

        if not spaces:
            raise RuntimeError(
                f"No spaces found at {url}! "
                "Check your credentials and wiki_base and make sure "
                "is_cloud is set correctly."
            )

        logger.info("Confluence probe succeeded.")
|
||||
|
||||
def _initialize_connection(
    self,
    **kwargs: Any,
) -> None:
    """Called externally to init the connection in a thread safe manner.

    While holding the credentials provider's context, fetches the current
    credentials, builds the client from them and remembers the constructor
    kwargs so the client can be rebuilt later.

    Args:
        **kwargs: Extra keyword arguments for the ``Confluence`` constructor;
            they take precedence over ``shared_base_kwargs``.
    """
    connection_kwargs = {**self.shared_base_kwargs, **kwargs}
    with self._credentials_provider:
        current_credentials, _ = self._renew_credentials()
        client = self._initialize_connection_helper(
            current_credentials, **connection_kwargs
        )
        self._confluence = client
        self._kwargs = connection_kwargs
|
||||
|
||||
def _initialize_connection_helper(
    self,
    credentials: dict[str, Any],
    **kwargs: Any,
) -> Confluence:
    """Called internally to init the connection. Distributed locking
    to prevent multiple threads from modifying the credentials
    must be handled around this function.

    Args:
        credentials: The credential dict to authenticate with. The presence
            of ``confluence_refresh_token`` selects OAuth; otherwise the
            access token is used as a password (cloud) or token (non-cloud).
        **kwargs: Keyword arguments passed through to ``Confluence``.

    Returns:
        The newly constructed ``Confluence`` client.
    """
    if "confluence_refresh_token" in credentials:
        logger.info("Connecting to Confluence Cloud with OAuth Access Token.")

        oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
        oauth_url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
        return Confluence(url=oauth_url, oauth2=oauth2_dict, **kwargs)

    logger.info("Connecting to Confluence with Personal Access Token.")
    if self._is_cloud:
        return Confluence(
            url=self._url,
            username=credentials["confluence_username"],
            password=credentials["confluence_access_token"],
            **kwargs,
        )

    return Confluence(
        url=self._url,
        token=credentials["confluence_access_token"],
        **kwargs,
    )
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def _make_rate_limited_confluence_method(
    self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
    """Wrap the client method called ``name`` with credential renewal and retries.

    Args:
        name: Name of the method to call on the underlying client. It is
            looked up on every attempt so a rebuilt client is picked up.
        credential_provider: When given, each attempt runs inside its
            context, renews the credentials first and rebuilds the client if
            they changed.

    Returns:
        A callable that forwards its arguments to the client method and
        retries up to ``MAX_RETRIES`` times on ``HTTPError`` (as decided by
        ``_handle_http_error``) and on ``AttributeError``.
    """

    def wrapped_call(*args: Any, **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 600
        timeout_at = time.monotonic() + TIMEOUT

        def call_client() -> Any:
            # resolve the attribute at call time: self._confluence may have
            # been replaced by a credential renewal
            attr = getattr(self._confluence, name, None)
            if attr is None:
                # The underlying Confluence client doesn't have this attribute
                raise AttributeError(
                    f"'{type(self).__name__}' object has no attribute '{name}'"
                )

            return attr(*args, **kwargs)

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            # we're relying more on the client to rate limit itself
            # and applying our own retries in a more specific set of circumstances
            try:
                if credential_provider:
                    with credential_provider:
                        credentials, renewed = self._renew_credentials()
                        if renewed:
                            self._confluence = self._initialize_connection_helper(
                                credentials, **self._kwargs
                            )
                        return call_client()

                return call_client()

            except HTTPError as e:
                # _handle_http_error returns a time.monotonic() deadline
                delay_until = _handle_http_error(e, attempt)
                if attempt == MAX_RETRIES - 1:
                    # out of attempts: surface the error instead of sleeping
                    # and then silently returning None
                    raise

                wait_seconds = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {wait_seconds:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

        # every path above returns or raises on the final attempt
        raise RuntimeError(f"Confluence call '{name}' exhausted its retries.")

    return wrapped_call
|
||||
|
||||
# def _wrap_methods(self) -> None:
|
||||
# """
|
||||
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
# wrap it with handle_confluence_rate_limit.
|
||||
# """
|
||||
# for attr_name in dir(self):
|
||||
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
# setattr(
|
||||
# self,
|
||||
# attr_name,
|
||||
# handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# )
|
||||
|
||||
# def _ensure_token_valid(self) -> None:
|
||||
# if self._token_is_expired():
|
||||
# self._refresh_token()
|
||||
# # Re-init the Confluence client with the originally stored args
|
||||
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
    """Dynamically intercept attribute/method access.

    Only invoked when normal attribute lookup fails. The lookup is delegated
    to the wrapped client: non-callable attributes and names starting with
    ``_`` are returned untouched, every other method is returned wrapped by
    ``_make_rate_limited_confluence_method``.

    Raises:
        AttributeError: If the wrapped client is not set yet, or if it has no
            such attribute (or the attribute's value is None).
    """
    # If "_confluence" itself is missing (instance made via __new__, copy or
    # unpickling before __init__ ran), reading self._confluence below would
    # re-enter this method and recurse until RecursionError.
    if name == "_confluence":
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    attr = getattr(self._confluence, name, None)
    if attr is None:
        # The underlying Confluence client doesn't have this attribute
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    # plain data attributes and private methods are handed back as they are
    if not callable(attr) or name.startswith("_"):
        return attr

    # wrap the method with our retry handler
    return self._make_rate_limited_confluence_method(
        name, self._credentials_provider
    )
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
@@ -507,63 +758,212 @@ class OnyxConfluence(Confluence):
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
) -> str | None:
    """Look up a user's email via ``get_mobile_parameters``, with caching.

    Results are stored in the module-level ``_USER_EMAIL_CACHE``. A failed or
    empty lookup is stored as None, which is indistinguishable from "not
    cached", so it is attempted again on the next call.

    Args:
        confluence_client: Client used for the lookup.
        user_name: The user name to resolve.

    Returns:
        The email, or None when the lookup failed or returned no email.
    """
    cached_email = _USER_EMAIL_CACHE.get(user_name)
    if cached_email is not None:
        return cached_email

    try:
        email = confluence_client.get_mobile_parameters(user_name).get("email")
    except Exception:
        logger.warning(f"failed to get confluence email for {user_name}")
        # For now, we'll just return None and log a warning. This means
        # we will keep retrying to get the email every group sync.
        # We may want to store a string that indicates failure instead so we
        # don't keep retrying.
        email = None

    _USER_EMAIL_CACHE[user_name] = email
    return email
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
    """Get Confluence Display Name based on the account-id or userkey value

    The id is tried as a userkey first and as an account id second. The
    outcome is stored in ``_USER_ID_TO_DISPLAY_NAME_CACHE``; ids that could
    not be resolved are looked up again on the next call.

    Args:
        user_id (str): The user id (i.e: the account-id or userkey)
        confluence_client (Confluence): The Confluence Client

    Returns:
        str: The User Display Name, or ``_USER_NOT_FOUND`` when neither
        lookup produced a display name.
    """
    if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
        display_name = None
        for lookup_name in (
            "get_user_details_by_userkey",
            "get_user_details_by_accountid",
        ):
            try:
                details = getattr(confluence_client, lookup_name)(user_id)
                display_name = details.get("displayName")
            except Exception:
                display_name = None

            if display_name:
                break

        _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = display_name

    return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
    parent_content_id: str | None = None,
) -> str | None:
    """Download an attachment and return its extracted text.

    If it returns None, assume that we should skip this attachment. That
    happens for rejected file types, a missing ``parent_content_id`` on
    ``api.atlassian.com`` urls, attachments over the size threshold, failed
    downloads and extracted text over the character threshold.

    Args:
        confluence_client: Client whose ``url`` and ``_session`` are used.
        attachment: The attachment object as returned by Confluence.
        parent_content_id: Id of the content the attachment belongs to;
            required when the client url contains ``api.atlassian.com``.
    """
    if not validate_attachment_filetype(attachment):
        return None

    if "api.atlassian.com" not in confluence_client.url:
        download_url = confluence_client.url + attachment["_links"]["download"]
    elif parent_content_id:
        # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
        download_url = (
            confluence_client.url
            + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
        )
    else:
        logger.warning(
            "parent_content_id is required to download attachments from Confluence Cloud!"
        )
        return None

    size_in_bytes = attachment["extensions"]["fileSize"]
    if size_in_bytes > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
        logger.warning(
            f"Skipping {download_url} due to size. "
            f"size={size_in_bytes} "
            f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
        )
        return None

    logger.info(f"_attachment_to_content - _session.get: link={download_url}")

    # why are we using session.get here? we probably won't retry these ... is that ok?
    download = confluence_client._session.get(download_url)
    if download.status_code != 200:
        logger.warning(
            f"Failed to fetch {download_url} with invalid status code {download.status_code}"
        )
        return None

    text = extract_file_text(
        io.BytesIO(download.content),
        file_name=attachment["title"],
        break_on_unprocessable=False,
    )
    if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
        logger.warning(
            f"Skipping {download_url} due to char count. "
            f"char count={len(text)} "
            f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
        )
        return None

    return text
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
def build_confluence_client(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
try:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(str(e))
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
return format_document_soup(soup)
|
||||
|
||||
@@ -1,239 +1,280 @@
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.configs.constants import FileOrigin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import PGFileStore
|
||||
from onyx.db.pg_file_store import create_populate_lobj
|
||||
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
|
||||
from onyx.db.pg_file_store import upsert_pgfilestore
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
    confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
    """Resolve a user name to an email through ``get_mobile_parameters``.

    Emails are memoised in the module-level ``_USER_EMAIL_CACHE``. A lookup
    that fails or yields no email is recorded as None, so the same user is
    queried again the next time this function is called.

    Args:
        confluence_client: Client used to query Confluence.
        user_name: The user name to resolve.

    Returns:
        The email address, or None if it could not be determined.
    """
    known_email = _USER_EMAIL_CACHE.get(user_name)
    if known_email is not None:
        return known_email

    try:
        mobile_parameters = confluence_client.get_mobile_parameters(user_name)
        email = mobile_parameters.get("email")
    except Exception:
        logger.warning(f"failed to get confluence email for {user_name}")
        # For now, we'll just return None and log a warning. This means
        # we will keep retrying to get the email every group sync.
        # We may want to store a string that indicates failure instead so we
        # don't keep retrying.
        email = None

    _USER_EMAIL_CACHE[user_name] = email
    return email
|
||||
class TokenResponse(BaseModel):
    """Typed view of an OAuth token payload.

    NOTE(review): presumably the response of ``CONFLUENCE_OAUTH_TOKEN_URL``;
    confirm against the code that constructs this model.
    """

    access_token: str
    # NOTE(review): presumably a lifetime in seconds -- confirm
    expires_in: int
    token_type: str
    refresh_token: str
    scope: str
|
||||
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
def validate_attachment_filetype(
|
||||
attachment: dict[str, Any], llm: LLM | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
Validates if the attachment is a supported file type.
|
||||
If LLM is provided, also checks if it's an image that can be processed.
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
attachment.get("metadata", {})
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
if media_type.startswith("image/"):
|
||||
return llm is not None and is_valid_image_type(media_type)
|
||||
|
||||
# For non-image files, check if we support the extension
|
||||
title = attachment.get("title", "")
|
||||
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
|
||||
return extension in ["pdf", "doc", "docx", "txt", "md", "rtf"]
|
||||
|
||||
|
||||
class AttachmentProcessingResult(BaseModel):
    """
    A container for results after processing a Confluence attachment.
    """

    # the textual content of the attachment
    text: str | None
    # the final file name used in PGFileStore to store the content
    file_name: str | None
    # description of what failed, if anything did
    error: str | None = None
|
||||
|
||||
|
||||
def _download_attachment(
|
||||
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
|
||||
"""
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
resp = confluence_client._session.get(download_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with status code {resp.status_code}"
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
    """Return True unless the attachment's media type is on the skip list.

    The media type is read from ``attachment["metadata"]["mediaType"]``.
    A missing or empty ``metadata`` / ``mediaType`` entry is treated as an
    empty media type (which is not on the skip list) instead of raising
    ``KeyError``, matching the defensive lookup used when attachments are
    processed elsewhere in this module.

    Args:
        attachment: Attachment dictionary as returned by the Confluence API.

    Returns:
        False for the image/video media types listed below, True otherwise.
    """
    unsupported_media_types = {
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/svg+xml",
        "video/mp4",
        "video/quicktime",
    }
    # `or {}` also covers an explicit `"metadata": None` entry.
    media_type = (attachment.get("metadata") or {}).get("mediaType", "")
    return media_type not in unsupported_media_types
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""
|
||||
Processes a Confluence attachment. If it's a document, extracts text,
|
||||
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
|
||||
"""
|
||||
try:
|
||||
# Get the media type from the attachment metadata
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
# Validate the attachment type
|
||||
if not validate_attachment_filetype(attachment, llm):
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Unsupported file type: {media_type}",
|
||||
)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
# Download the attachment
|
||||
raw_bytes = _download_attachment(confluence_client, attachment)
|
||||
if raw_bytes is None:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="Failed to download attachment"
|
||||
)
|
||||
|
||||
# Process image attachments with LLM if available
|
||||
if media_type.startswith("image/") and llm:
|
||||
return _process_image_attachment(
|
||||
confluence_client, attachment, page_context, llm, raw_bytes, media_type
|
||||
)
|
||||
|
||||
# Process document attachments
|
||||
try:
|
||||
text = extract_file_text(
|
||||
file=BytesIO(raw_bytes),
|
||||
file_name=attachment["title"],
|
||||
)
|
||||
|
||||
# Skip if the text is too long
|
||||
if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment text too long: {len(text)} chars",
|
||||
)
|
||||
|
||||
return AttachmentProcessingResult(text=text, file_name=None, error=None)
|
||||
except Exception as e:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error=f"Failed to extract text: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error=f"Failed to process attachment: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
|
||||
def _process_image_attachment(
    confluence_client: "OnyxConfluence",
    attachment: dict[str, Any],
    page_context: str,
    llm: LLM,
    raw_bytes: bytes,
    media_type: str,
) -> AttachmentProcessingResult:
    """Store an image attachment and return the text of the resulting section.

    The image bytes are handed to ``store_image_and_create_section`` together
    with the LLM; the returned section's text and the stored file name are
    packaged into an ``AttachmentProcessingResult``. Any exception is logged
    and reported through the result's ``error`` field rather than raised.

    Note: ``confluence_client`` and ``page_context`` are accepted but not read
    by this function.
    """
    try:
        # Use the standardized image storage and section creation
        with get_session_with_current_tenant() as db_session:
            stored_name = Path(attachment["id"]).name
            shown_name = attachment["title"]
            section, file_name = store_image_and_create_section(
                db_session=db_session,
                image_data=raw_bytes,
                file_name=stored_name,
                display_name=shown_name,
                media_type=media_type,
                llm=llm,
                file_origin=FileOrigin.CONNECTOR,
            )
            return AttachmentProcessingResult(
                text=section.text,
                file_name=file_name,
                error=None,
            )
    except Exception as e:
        msg = f"Image summarization failed for {attachment['title']}: {e}"
        logger.error(msg, exc_info=e)
        return AttachmentProcessingResult(text=None, file_name=None, error=msg)
|
||||
|
||||
|
||||
def _process_text_attachment(
|
||||
attachment: dict[str, Any],
|
||||
raw_bytes: bytes,
|
||||
media_type: str,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""Process a text-based attachment by extracting its content."""
|
||||
try:
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(raw_bytes),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
msg = f"Failed to extract text for '{attachment['title']}': {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
|
||||
|
||||
# Check length constraints
|
||||
if extracted_text is None or len(extracted_text) == 0:
|
||||
msg = f"No text extracted for {attachment['title']}"
|
||||
logger.warning(msg)
|
||||
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
msg = (
|
||||
f"Skipping attachment {attachment['title']} due to char count "
|
||||
f"({len(extracted_text)} > {CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD})"
|
||||
)
|
||||
logger.warning(msg)
|
||||
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
|
||||
|
||||
# Save the attachment
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
saved_record = save_bytes_to_pgfilestore(
|
||||
db_session=db_session,
|
||||
raw_bytes=raw_bytes,
|
||||
media_type=media_type,
|
||||
identifier=attachment["id"],
|
||||
display_name=attachment["title"],
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Failed to save attachment '{attachment['title']}' to PG: {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
return AttachmentProcessingResult(
|
||||
text=extracted_text, file_name=None, error=msg
|
||||
)
|
||||
|
||||
return AttachmentProcessingResult(
|
||||
text=extracted_text, file_name=saved_record.file_name, error=None
|
||||
)
|
||||
|
||||
|
||||
def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
"""
|
||||
Facade function which:
|
||||
1. Validates attachment type
|
||||
2. Extracts or summarizes content
|
||||
3. Returns (content_text, stored_file_name) or None if we should skip it
|
||||
"""
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
# Quick check for unsupported types:
|
||||
if media_type.startswith("video/") or media_type == "application/gliffy+json":
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
result = process_attachment(confluence_client, attachment, page_context, llm)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Return the text and the file name
|
||||
return result.text, result.file_name
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
@@ -254,23 +295,6 @@ def build_confluence_document_id(
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
    """Collect the filenames of the attachments referenced by a Confluence page.

    The page HTML is parsed with BeautifulSoup and the ``ri:filename``
    attribute of every ``ri:attachment`` tag is gathered, in document order.

    Args:
        page_text (str): The page content

    Returns:
        list[str]: List of filenames currently in use by the page text
    """
    parsed_page = bs4.BeautifulSoup(page_text, "html.parser")
    return [
        attachment_tag.attrs["ri:filename"]
        for attachment_tag in parsed_page.findAll("ri:attachment")
    ]
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
@@ -284,6 +308,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
return datetime_object
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
    client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
    """Exchange a refresh token for a new access/refresh token pair.

    Posts a ``refresh_token`` grant to ``CONFLUENCE_OAUTH_TOKEN_URL`` and
    parses the response body as a ``TokenResponse``.

    Args:
        client_id: OAuth client id sent with the grant request.
        client_secret: OAuth client secret sent with the grant request.
        cloud_id: Copied unchanged into the returned credentials.
        refresh_token: The refresh token to redeem.

    Returns:
        A new credentials dict holding the access token, refresh token,
        creation/expiry timestamps (ISO format, UTC), ``expires_in``,
        ``scope`` and ``cloud_id``.

    Raises:
        RuntimeError: If the response body cannot be validated as a
            ``TokenResponse``. The original error is chained as ``__cause__``.
    """
    # rotate the refresh and access token
    # Note that access tokens are only good for an hour in confluence cloud,
    # so we're going to have problems if the connector runs for longer
    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
    response = requests.post(
        CONFLUENCE_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data={
            "grant_type": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
        },
        # requests never times out by default; bound the call so a stalled
        # token endpoint cannot hang the caller indefinitely.
        timeout=60,
    )

    try:
        token_response = TokenResponse.model_validate_json(response.text)
    except Exception as e:
        # Chain the parsing failure so the root cause stays in the traceback.
        raise RuntimeError("Confluence Cloud token refresh failed.") from e

    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=token_response.expires_in)

    new_credentials: dict[str, Any] = {
        "confluence_access_token": token_response.access_token,
        "confluence_refresh_token": token_response.refresh_token,
        "created_at": now.isoformat(),
        "expires_at": expires_at.isoformat(),
        "expires_in": token_response.expires_in,
        "scope": token_response.scope,
        "cloud_id": cloud_id,
    }
    return new_credentials
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])


# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence client call with retry handling.

    The wrapped callable is attempted up to ``MAX_RETRIES`` times:

    * ``requests.HTTPError`` is passed to ``_handle_http_error``, which
      re-raises errors it does not treat as rate limiting and otherwise
      returns a ``time.monotonic()`` deadline to wait for before retrying.
    * ``AttributeError`` is retried after a fixed 5 second pause.

    On the final attempt the exception is re-raised instead of being
    swallowed, so the wrapper never silently returns ``None`` after
    exhausting its retries.

    Args:
        confluence_call: The callable to protect.

    Returns:
        A wrapper with the same call signature and metadata as
        ``confluence_call``.

    Raises:
        TimeoutError: If a new attempt would start more than ``TIMEOUT``
            seconds after the first one.
        requests.HTTPError: For non-rate-limit errors, or when the last
            attempt is still rate limited.
        AttributeError: When the last attempt still raises it.
    """
    # Local import keeps this block self-contained with respect to the
    # module's import list.
    import functools

    MAX_RETRIES = 5
    TIMEOUT = 600

    @functools.wraps(confluence_call)
    def wrapped_call(*args: Any, **kwargs: Any) -> Any:
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except requests.HTTPError as e:
                # Raises for anything that is not a rate-limit response.
                delay_until = _handle_http_error(e, attempt)

                if attempt == MAX_RETRIES - 1:
                    # Out of retries: surface the error rather than waiting
                    # and then falling out of the loop with no result.
                    raise

                # delay_until is an absolute monotonic deadline; log the
                # remaining wait instead of the raw clock value.
                remaining = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {remaining:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError as e:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise e

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
    """Decide how long to wait after an HTTP error, or re-raise it.

    The error is re-raised when it has no response/headers, or when it is
    neither a 429 nor contains ``RATE_LIMIT_MESSAGE_LOWERCASE`` in its body.
    Otherwise a delay is chosen: the integer ``Retry-After`` header clamped
    to ``[MIN_DELAY, MAX_DELAY]`` when present and parseable, else an
    exponential backoff based on ``attempt`` capped at ``MAX_DELAY``.

    Args:
        e: The HTTP error raised by the call.
        attempt: Zero-based retry attempt, used for the backoff exponent.

    Returns:
        The ``time.monotonic()`` value (rounded up) until which the caller
        should wait.
    """
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    # Check if the response or headers are None to avoid potential AttributeError
    response = e.response
    if response is None or response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    is_rate_limited = (
        response.status_code == 429
        or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
    )
    if not is_rate_limited:
        raise e

    server_delay: int | None = None
    header_value = response.headers.get("Retry-After")
    if header_value is not None:
        try:
            server_delay = int(header_value)
        except ValueError:
            # Non-integer header: fall through to the backoff path.
            server_delay = None
        else:
            if server_delay > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {server_delay} to {MAX_DELAY} seconds..."
                )
                server_delay = MAX_DELAY
            if server_delay < MIN_DELAY:
                server_delay = MIN_DELAY

    if server_delay is None:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    else:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {server_delay} seconds..."
        )
        delay = server_delay

    return math.ceil(time.monotonic() + delay)
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
@@ -311,3 +466,37 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
|
||||
|
||||
def attachment_to_file_record(
    confluence_client: "OnyxConfluence",
    attachment: dict[str, Any],
    db_session: Session,
) -> tuple[PGFileStore, bytes]:
    """Download an attachment, persist it, and return its file-store record.

    The attachment is fetched through ``confluence_client.get``, written to
    a large object via ``create_populate_lobj`` and registered with
    ``upsert_pgfilestore`` (committing) under the name
    ``confluence_attachment_<attachment id>``.

    Returns:
        A tuple of the file-store record and the downloaded bytes.
    """
    raw_data = confluence_client.get(
        _attachment_to_download_link(confluence_client, attachment),
        absolute=True,
        not_json_response=True,
    )

    # Save image to file store
    stored_name = f"confluence_attachment_{attachment['id']}"
    large_object_oid = create_populate_lobj(BytesIO(raw_data), db_session)
    file_record = upsert_pgfilestore(
        file_name=stored_name,
        display_name=attachment["title"],
        file_origin=FileOrigin.OTHER,
        file_type=attachment["metadata"]["mediaType"],
        lobj_oid=large_object_oid,
        db_session=db_session,
        commit=True,
    )

    return file_record, raw_data
|
||||
|
||||
|
||||
def _attachment_to_download_link(
|
||||
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
|
||||
) -> str:
|
||||
"""Extracts the download link to images."""
|
||||
return confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
135
backend/onyx/connectors/credentials_provider.py
Normal file
135
backend/onyx/connectors/credentials_provider.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import Credential
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class OnyxDBCredentialsProvider(
    CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
    """Implementation to allow the connector to callback and update credentials in the db.
    Required in cases where credentials can rotate while the connector is running.

    Used as a context manager: entering acquires a Redis lock keyed by the
    connector name and credential id, and exiting releases it if still owned.
    """

    LOCK_TTL = 900  # TTL of the lock

    def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
        """Bind the provider to one credential row and prepare its Redis lock.

        Args:
            tenant_id: Tenant used for the Redis client and DB sessions.
            connector_name: Connector name, used in the lock key.
            credential_id: Primary key of the ``Credential`` row.
        """
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_id = credential_id

        self.redis_client = get_redis_client(tenant_id=tenant_id)

        # lock used to prevent overlapping renewal of credentials
        self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
        self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)

    def __enter__(self) -> "OnyxDBCredentialsProvider":
        """Acquire the lock, waiting up to ``LOCK_TTL`` seconds.

        Raises:
            RuntimeError: If the lock could not be acquired in time.
        """
        acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
        if not acquired:
            raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release the lock when exiting the context."""
        if self._lock and self._lock.owned():
            self._lock.release()

    def get_tenant_id(self) -> str | None:
        """Return the tenant id this provider was created with."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the credential id as a string."""
        return str(self._credential_id)

    def get_credentials(self) -> dict[str, Any]:
        """Load the credential JSON from the database.

        Raises:
            ValueError: If no credential row exists for the configured id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            # scalar_one_or_none() returns None for a missing row, so the
            # explicit check below is reachable (scalar_one() would raise
            # NoResultFound before it).
            credential = db_session.execute(
                select(Credential).where(Credential.id == self._credential_id)
            ).scalar_one_or_none()

            if credential is None:
                raise ValueError(
                    f"No credential found: credential={self._credential_id}"
                )

            return credential.credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Overwrite the stored credential JSON and commit.

        The row is selected ``FOR UPDATE``; on any failure the session is
        rolled back and the exception re-raised.

        Raises:
            ValueError: If no credential row exists for the configured id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            try:
                credential = db_session.execute(
                    select(Credential)
                    .where(Credential.id == self._credential_id)
                    .with_for_update()
                ).scalar_one_or_none()

                if credential is None:
                    raise ValueError(
                        f"No credential found: credential={self._credential_id}"
                    )

                credential.credential_json = credential_json
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise

    def is_dynamic(self) -> bool:
        """Return True: credentials served by this provider can change."""
        return True
|
||||
|
||||
|
||||
class OnyxStaticCredentialsProvider(
    CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
    """Implementation (a very simple one!) to handle static credentials."""

    def __init__(
        self,
        tenant_id: str | None,
        connector_name: str,
        credential_json: dict[str, Any],
    ):
        """Hold the given credentials in memory under a random provider key."""
        self._provider_key = str(uuid.uuid4())

        self._credential_json = credential_json
        self._connector_name = connector_name
        self._tenant_id = tenant_id

    def __enter__(self) -> "OnyxStaticCredentialsProvider":
        """Enter the context; nothing is acquired."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Exit the context; there is nothing to release."""
        return None

    def get_tenant_id(self) -> str | None:
        """Return the tenant id supplied at construction."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the UUID string generated at construction."""
        return self._provider_key

    def get_credentials(self) -> dict[str, Any]:
        """Return the credential JSON currently held in memory."""
        return self._credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Replace the in-memory credential JSON."""
        self._credential_json = credential_json

    def is_dynamic(self) -> bool:
        """Return False for this static provider."""
        return False
|
||||
@@ -14,12 +14,15 @@ class ConnectorValidationError(ValidationError):
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class UnexpectedError(ValidationError):
|
||||
class UnexpectedValidationError(ValidationError):
|
||||
"""Raised when an unexpected error occurs during connector validation.
|
||||
|
||||
Unexpected errors don't necessarily mean the credential is invalid,
|
||||
but rather that there was an error during the validation process
|
||||
or we encountered a currently unhandled error case.
|
||||
|
||||
Currently, unexpected validation errors are defined as transient and should not be
|
||||
used to disable the connector.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Unexpected error during connector validation"):
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
@@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.models import Credential
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@@ -167,10 +170,17 @@ def instantiate_connector(
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
if isinstance(connector, CredentialsConnector):
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), str(source), credential.id
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
@@ -10,22 +10,23 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -35,81 +36,115 @@ def _read_files_and_metadata(
|
||||
file_name: str,
|
||||
db_session: Session,
|
||||
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
|
||||
"""Reads the file into IO, in the case of a zip file, yields each individual
|
||||
file contained within, also includes the metadata dict if packaged in the zip"""
|
||||
"""
|
||||
Reads the file from Postgres. If the file is a .zip, yields subfiles.
|
||||
"""
|
||||
extension = get_file_ext(file_name)
|
||||
metadata: dict[str, Any] = {}
|
||||
directory_path = os.path.dirname(file_name)
|
||||
|
||||
# Read file from Postgres store
|
||||
file_content = get_default_file_store(db_session).read_file(file_name, mode="b")
|
||||
|
||||
# If it's a zip, expand it
|
||||
if extension == ".zip":
|
||||
for file_info, file, metadata in load_files_from_zip(
|
||||
for file_info, subfile, metadata in load_files_from_zip(
|
||||
file_content, ignore_dirs=True
|
||||
):
|
||||
yield os.path.join(directory_path, file_info.filename), file, metadata
|
||||
yield os.path.join(directory_path, file_info.filename), subfile, metadata
|
||||
elif is_valid_file_ext(extension):
|
||||
yield file_name, file_content, metadata
|
||||
else:
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
|
||||
|
||||
def _create_image_section(
    llm: LLM | None,
    image_data: bytes,
    db_session: Session,
    parent_file_name: str,
    display_name: str,
    idx: int = 0,
) -> tuple[Section, str | None]:
    """
    Build a Section for one image by delegating to store_image_and_create_section.

    The image is stored under the name "<parent_file_name>_embedded_<idx>"
    and the LLM (possibly None) is passed through to the storage helper.

    Returns:
        tuple: (Section object, file_name in PGFileStore or None if storage failed)
    """
    # The embedded image gets its own unique name derived from the parent
    # file, then goes through the standardized storage utility.
    return store_image_and_create_section(
        db_session=db_session,
        image_data=image_data,
        file_name=f"{parent_file_name}_embedded_{idx}",
        display_name=display_name,
        llm=llm,
        file_origin=FileOrigin.OTHER,
    )
|
||||
|
||||
|
||||
def _process_file(
|
||||
file_name: str,
|
||||
file: IO[Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
pdf_pass: str | None = None,
|
||||
metadata: dict[str, Any] | None,
|
||||
pdf_pass: str | None,
|
||||
db_session: Session,
|
||||
llm: LLM | None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Processes a single file, returning a list of Documents (typically one).
|
||||
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
|
||||
"""
|
||||
extension = get_file_ext(file_name)
|
||||
if not is_valid_file_ext(extension):
|
||||
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
|
||||
|
||||
# Fetch the DB record so we know the ID for internal URL
|
||||
pg_record = get_pgfilestore_by_file_name(file_name=file_name, db_session=db_session)
|
||||
if not pg_record:
|
||||
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
|
||||
return []
|
||||
|
||||
file_metadata: dict[str, Any] = {}
|
||||
|
||||
if is_text_file_extension(file_name):
|
||||
encoding = detect_encoding(file)
|
||||
file_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
if not is_valid_file_ext(extension):
|
||||
logger.warning(
|
||||
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
|
||||
)
|
||||
return []
|
||||
|
||||
# Using the PDF reader function directly to pass in password cleanly
|
||||
elif extension == ".pdf" and pdf_pass is not None:
|
||||
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
|
||||
# Prepare doc metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
|
||||
|
||||
else:
|
||||
file_content_raw = extract_file_text(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
break_on_unprocessable=True,
|
||||
)
|
||||
|
||||
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
|
||||
|
||||
# add a prefix to avoid conflicts with other connectors
|
||||
doc_id = f"FILE_CONNECTOR__{file_name}"
|
||||
if metadata:
|
||||
doc_id = metadata.get("document_id") or doc_id
|
||||
|
||||
# If this is set, we will show this in the UI as the "name" of the file
|
||||
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
|
||||
file_name
|
||||
)
|
||||
title = (
|
||||
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
|
||||
)
|
||||
|
||||
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
|
||||
# Timestamps
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
time_updated = metadata.get("time_updated", current_datetime)
|
||||
if isinstance(time_updated, str):
|
||||
time_updated = time_str_to_utc(time_updated)
|
||||
|
||||
dt_str = all_metadata.get("doc_updated_at")
|
||||
dt_str = metadata.get("doc_updated_at")
|
||||
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
|
||||
|
||||
# Metadata tags separate from the Onyx specific fields
|
||||
# Collect owners
|
||||
p_owner_names = metadata.get("primary_owners")
|
||||
s_owner_names = metadata.get("secondary_owners")
|
||||
p_owners = (
|
||||
[BasicExpertInfo(display_name=name) for name in p_owner_names]
|
||||
if p_owner_names
|
||||
else None
|
||||
)
|
||||
s_owners = (
|
||||
[BasicExpertInfo(display_name=name) for name in s_owner_names]
|
||||
if s_owner_names
|
||||
else None
|
||||
)
|
||||
|
||||
# Additional tags we store as doc metadata
|
||||
metadata_tags = {
|
||||
k: v
|
||||
for k, v in all_metadata.items()
|
||||
for k, v in metadata.items()
|
||||
if k
|
||||
not in [
|
||||
"document_id",
|
||||
@@ -122,77 +157,142 @@ def _process_file(
|
||||
"file_display_name",
|
||||
"title",
|
||||
"connector_type",
|
||||
"pdf_password",
|
||||
]
|
||||
}
|
||||
|
||||
source_type_str = all_metadata.get("connector_type")
|
||||
source_type = DocumentSource(source_type_str) if source_type_str else None
|
||||
|
||||
p_owner_names = all_metadata.get("primary_owners")
|
||||
s_owner_names = all_metadata.get("secondary_owners")
|
||||
p_owners = (
|
||||
[BasicExpertInfo(display_name=name) for name in p_owner_names]
|
||||
if p_owner_names
|
||||
else None
|
||||
)
|
||||
s_owners = (
|
||||
[BasicExpertInfo(display_name=name) for name in s_owner_names]
|
||||
if s_owner_names
|
||||
else None
|
||||
source_type_str = metadata.get("connector_type")
|
||||
source_type = (
|
||||
DocumentSource(source_type_str) if source_type_str else DocumentSource.FILE
|
||||
)
|
||||
|
||||
doc_id = metadata.get("document_id") or f"FILE_CONNECTOR__{file_name}"
|
||||
title = metadata.get("title") or file_display_name
|
||||
|
||||
# 1) If the file itself is an image, handle that scenario quickly
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
# Summarize or produce empty doc
|
||||
image_data = file.read()
|
||||
image_section, _ = _create_image_section(
|
||||
llm, image_data, db_session, pg_record.file_name, title
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
id=doc_id,
|
||||
sections=[image_section],
|
||||
source=source_type,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
metadata=metadata_tags,
|
||||
)
|
||||
]
|
||||
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
|
||||
# (For example .docx with inline images).
|
||||
file.seek(0)
|
||||
text_content = ""
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
text_content, embedded_images = extract_text_and_images(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
pdf_pass=pdf_pass,
|
||||
)
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections = []
|
||||
link_in_meta = metadata.get("link")
|
||||
if text_content.strip():
|
||||
sections.append(Section(link=link_in_meta, text=text_content.strip()))
|
||||
|
||||
# Then any extracted images from docx, etc.
|
||||
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
|
||||
# Store each embedded image as a separate file in PGFileStore
|
||||
# and create a section with the image summary
|
||||
image_section, _ = _create_image_section(
|
||||
llm,
|
||||
img_data,
|
||||
db_session,
|
||||
pg_record.file_name,
|
||||
f"{title} - image {idx}",
|
||||
idx,
|
||||
)
|
||||
sections.append(image_section)
|
||||
return [
|
||||
Document(
|
||||
id=doc_id,
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=source_type or DocumentSource.FILE,
|
||||
sections=sections,
|
||||
source=source_type,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
# currently metadata just houses tags, other stuff like owners / updated at have dedicated fields
|
||||
metadata=metadata_tags,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class LocalFileConnector(LoadConnector):
|
||||
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
|
||||
"""
|
||||
Connector that reads files from Postgres and yields Documents, including
|
||||
optional embedded image extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.file_locations = [str(loc) for loc in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.pdf_pass = credentials.get("pdf_password")
|
||||
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Iterates over each file path, fetches from Postgres, tries to parse text
|
||||
or images, and yields Document batches.
|
||||
"""
|
||||
documents: list[Document] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
file_name=str(file_path), db_session=db_session
|
||||
|
||||
files_iter = _read_files_and_metadata(
|
||||
file_name=file_path,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
for file_name, file, metadata in files:
|
||||
for actual_file_name, file, metadata in files_iter:
|
||||
metadata["time_updated"] = metadata.get(
|
||||
"time_updated", current_datetime
|
||||
)
|
||||
documents.extend(
|
||||
_process_file(file_name, file, metadata, self.pdf_pass)
|
||||
new_docs = _process_file(
|
||||
file_name=actual_file_name,
|
||||
file=file,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
db_session=db_session,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
@@ -201,7 +301,7 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]})
|
||||
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
|
||||
doc_batches = connector.load_from_state()
|
||||
for batch in doc_batches:
|
||||
print("BATCH:", batch)
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -284,7 +284,7 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
user.get_repos().totalCount # Just check if we can access repos
|
||||
|
||||
except RateLimitExceededException:
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
|
||||
|
||||
@@ -4,14 +4,12 @@ from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
@@ -36,7 +34,6 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
|
||||
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
|
||||
from onyx.connectors.google_utils.shared_constants import SCOPE_DOC_URL
|
||||
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
|
||||
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -46,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -66,7 +65,10 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
|
||||
|
||||
|
||||
def _convert_single_file(
|
||||
creds: Any, primary_admin_email: str, file: dict[str, Any]
|
||||
creds: Any,
|
||||
primary_admin_email: str,
|
||||
file: dict[str, Any],
|
||||
image_analysis_llm: LLM | None,
|
||||
) -> Any:
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
|
||||
user_drive_service = get_drive_service(creds, user_email=user_email)
|
||||
@@ -75,11 +77,14 @@ def _convert_single_file(
|
||||
file=file,
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
|
||||
)
|
||||
|
||||
|
||||
def _process_files_batch(
|
||||
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
|
||||
files: list[GoogleDriveFileType],
|
||||
convert_func: Callable[[GoogleDriveFileType], Any],
|
||||
batch_size: int,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
|
||||
@@ -111,7 +116,9 @@ def _clean_requested_drive_ids(
|
||||
return valid_requested_drive_ids, filtered_folder_ids
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class GoogleDriveConnector(
|
||||
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
include_shared_drives: bool = False,
|
||||
@@ -129,23 +136,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
continue_on_failure: bool | None = None,
|
||||
) -> None:
|
||||
# Check for old input parameters
|
||||
if (
|
||||
folder_paths is not None
|
||||
or include_shared is not None
|
||||
or follow_shortcuts is not None
|
||||
or only_org_public is not None
|
||||
or continue_on_failure is not None
|
||||
):
|
||||
logger.exception(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
if folder_paths is not None:
|
||||
logger.warning(
|
||||
"The 'folder_paths' parameter is deprecated. Use 'shared_folder_urls' instead."
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
if include_shared is not None:
|
||||
logger.warning(
|
||||
"The 'include_shared' parameter is deprecated. Use 'include_files_shared_with_me' instead."
|
||||
)
|
||||
if follow_shortcuts is not None:
|
||||
logger.warning("The 'follow_shortcuts' parameter is deprecated.")
|
||||
if only_org_public is not None:
|
||||
logger.warning("The 'only_org_public' parameter is deprecated.")
|
||||
if continue_on_failure is not None:
|
||||
logger.warning("The 'continue_on_failure' parameter is deprecated.")
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
if (
|
||||
not include_shared_drives
|
||||
@@ -237,6 +244,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
credentials=credentials,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
return new_creds_dict
|
||||
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
@@ -523,37 +531,53 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
# Create a larger process pool for file conversion
|
||||
convert_func = partial(
|
||||
_convert_single_file, self.creds, self.primary_admin_email
|
||||
)
|
||||
|
||||
# Process files in larger batches
|
||||
LARGE_BATCH_SIZE = self.batch_size * 4
|
||||
files_to_process = []
|
||||
# Gather the files into batches to be processed in parallel
|
||||
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
|
||||
if (
|
||||
file.get("size")
|
||||
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
|
||||
):
|
||||
logger.warning(
|
||||
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
|
||||
)
|
||||
continue
|
||||
|
||||
files_to_process.append(file)
|
||||
if len(files_to_process) >= LARGE_BATCH_SIZE:
|
||||
yield from _process_files_batch(
|
||||
files_to_process, convert_func, self.batch_size
|
||||
)
|
||||
files_to_process = []
|
||||
|
||||
# Process any remaining files
|
||||
if files_to_process:
|
||||
yield from _process_files_batch(
|
||||
files_to_process, convert_func, self.batch_size
|
||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.primary_admin_email,
|
||||
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
|
||||
)
|
||||
|
||||
# Fetch files in batches
|
||||
files_batch: list[GoogleDriveFileType] = []
|
||||
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
|
||||
files_batch.append(file)
|
||||
|
||||
if len(files_batch) >= self.batch_size:
|
||||
# Process the batch
|
||||
futures = [
|
||||
executor.submit(convert_func, file) for file in files_batch
|
||||
]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file: {e}")
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
files_batch = []
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
futures = [executor.submit(convert_func, file) for file in files_batch]
|
||||
documents = []
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
doc = future.result()
|
||||
if doc is not None:
|
||||
documents.append(doc)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file: {e}")
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
try:
|
||||
yield from self._extract_docs_from_google_drive()
|
||||
|
||||
@@ -9,7 +9,7 @@ from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
@@ -21,32 +21,88 @@ from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.file_processing.extract_file_text import docx_to_text
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import docx_to_text_and_images
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import unstructured_to_text
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# these errors don't represent a failure in the connector, but simply files
|
||||
# that can't / shouldn't be indexed
|
||||
ERRORS_TO_CONTINUE_ON = [
|
||||
"cannotExportFile",
|
||||
"exportSizeLimitExceeded",
|
||||
"cannotDownloadFile",
|
||||
]
|
||||
def _summarize_drive_image(
|
||||
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
|
||||
) -> str:
|
||||
"""
|
||||
Summarize the given image using the provided LLM.
|
||||
"""
|
||||
if not image_analysis_llm:
|
||||
return ""
|
||||
|
||||
return (
|
||||
summarize_image_with_error_handling(
|
||||
llm=image_analysis_llm,
|
||||
image_data=image_data,
|
||||
context_name=image_name,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
|
||||
|
||||
def is_gdrive_image_mime_type(mime_type: str) -> bool:
    """
    Tell whether ``mime_type`` should be handled as an image.

    This is a thin wrapper: the decision is delegated entirely to
    ``is_valid_image_type``, so the set of accepted types is whatever that
    helper accepts (presumably values such as 'image/png' or 'image/jpeg' --
    confirm against the helper).

    Args:
        mime_type: MIME type string to check.

    Returns:
        The result of ``is_valid_image_type(mime_type)``.
    """
    is_image = is_valid_image_type(mime_type)
    return is_image
|
||||
|
||||
|
||||
def _extract_sections_basic(
|
||||
file: dict[str, str], service: GoogleDriveService
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
image_analysis_llm: LLM | None = None,
|
||||
) -> list[Section]:
|
||||
"""
|
||||
Extends the existing logic to handle either a docx with embedded images
|
||||
or standalone images (PNG, JPG, etc).
|
||||
"""
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
file_name = file.get("name", file["id"])
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
|
||||
if is_gdrive_image_mime_type(mime_type):
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file["id"],
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
llm=image_analysis_llm,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
return [section]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch or summarize image: {e}")
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text="",
|
||||
image_file_name=link,
|
||||
)
|
||||
]
|
||||
|
||||
if mime_type not in supported_file_types:
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
@@ -185,45 +241,63 @@ def _extract_sections_basic(
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=service.files()
|
||||
.get_media(fileId=file["id"])
|
||||
.execute()
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
text_data = (
|
||||
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
)
|
||||
return [Section(link=link, text=text_data)]
|
||||
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
elif mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
response_bytes = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
# Optionally use Unstructured
|
||||
if get_unstructured_api_key():
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text=unstructured_to_text(
|
||||
file=io.BytesIO(response),
|
||||
file_name=file.get("name", file["id"]),
|
||||
),
|
||||
)
|
||||
]
|
||||
text = unstructured_to_text(
|
||||
file=io.BytesIO(response_bytes),
|
||||
file_name=file_name,
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
return [
|
||||
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
# Use docx_to_text_and_images to get text plus embedded images
|
||||
text, embedded_images = docx_to_text_and_images(
|
||||
file=io.BytesIO(response_bytes),
|
||||
)
|
||||
sections = []
|
||||
if text.strip():
|
||||
sections.append(Section(link=link, text=text.strip()))
|
||||
|
||||
# Process each embedded image using the standardized function
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(
|
||||
embedded_images, start=1
|
||||
):
|
||||
# Create a unique identifier for the embedded image
|
||||
embedded_id = f"{file['id']}_embedded_{idx}"
|
||||
|
||||
section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_name=embedded_id,
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
llm=image_analysis_llm,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
sections.append(section)
|
||||
return sections
|
||||
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _ = read_pdf_file(file=io.BytesIO(response))
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
return [
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
text_data = pptx_to_text(io.BytesIO(response_bytes))
|
||||
return [Section(link=link, text=text_data)]
|
||||
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
@@ -231,7 +305,8 @@ def _extract_sections_basic(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting sections from file: {e}")
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
|
||||
@@ -239,74 +314,62 @@ def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
image_analysis_llm: LLM | None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
|
||||
"""
|
||||
try:
|
||||
# Skip files that are shortcuts
|
||||
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
|
||||
logger.info("Ignoring Drive Shortcut Filetype")
|
||||
return None
|
||||
# Skip files that are folders
|
||||
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
|
||||
logger.info("Ignoring Drive Folder Filetype")
|
||||
# skip shortcuts or folders
|
||||
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
|
||||
logger.info("Skipping shortcut/folder.")
|
||||
return None
|
||||
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
sections: list[Section] = []
|
||||
|
||||
# Special handling for Google Docs to preserve structure, link
|
||||
# to headers
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
# get_document_sections is the advanced approach for Google Docs
|
||||
sections = get_document_sections(docs_service, file["id"])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
f"Failed to pull google doc sections from '{file['name']}': {e}. "
|
||||
"Falling back to basic extraction."
|
||||
)
|
||||
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
|
||||
|
||||
# If not a doc, or if we failed above, do our 'basic' approach
|
||||
if not sections:
|
||||
try:
|
||||
# For all other file types just extract the text
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
|
||||
|
||||
except HttpError as e:
|
||||
reason = e.error_details[0]["reason"] if e.error_details else e.reason
|
||||
message = e.error_details[0]["message"] if e.error_details else e.reason
|
||||
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
|
||||
logger.warning(
|
||||
f"Could not export file '{file['name']}' due to '{message}', skipping..."
|
||||
)
|
||||
return None
|
||||
|
||||
raise
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
doc_id = file["webViewLink"]
|
||||
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=file["webViewLink"],
|
||||
id=doc_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
),
|
||||
metadata={}
|
||||
if any(section.text for section in sections)
|
||||
else {IGNORE_FOR_QA: "True"},
|
||||
doc_updated_at=updated_time,
|
||||
metadata={}, # or any metadata from 'file'
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
except Exception as e:
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise e
|
||||
|
||||
logger.exception("Ran into exception when pulling a file from Google Drive")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
# Skip files that are folders or shortcuts
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
|
||||
return SlimDocument(
|
||||
id=file["webViewLink"],
|
||||
perm_sync_data={
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
    """Abstract interface for an object that supplies credentials to a connector.

    The interface is a context manager (``__enter__`` / ``__exit__``) plus
    accessors for the tenant id, a provider key, and the credentials themselves.
    Every method here is abstract and raises ``NotImplementedError``; concrete
    providers must override all of them.
    """

    @abc.abstractmethod
    def __enter__(self) -> T:
        """Enter the provider's context and return a value of type ``T``."""
        raise NotImplementedError

    @abc.abstractmethod
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Leave the provider's context.

        Receives the standard context-manager exception triple, all ``None``
        when the ``with`` body finished without raising.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_tenant_id(self) -> str | None:
        """Return the tenant id as a string, or ``None``."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_provider_key(self) -> str:
        """Return a unique key that the connector can use to lock around a credential
        that might be used simultaneously.

        Will typically be the credential id, but can also just be something random
        in cases when there is nothing to lock (aka static credentials).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_credentials(self) -> dict[str, Any]:
        """Return the current credentials as a dictionary."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Replace the stored credentials with ``credential_json``.

        NOTE(review): presumably this persists the new credentials for other
        users of the same provider -- confirm against concrete implementations.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_dynamic(self) -> bool:
        """Report whether the credentials may change while they are in use.

        If dynamic, the credentials may change during usage, meaning the client
        needs to use the locking features of the credentials provider to operate
        correctly.

        If static, the client can simply reference the credentials once and use them
        through the entire indexing run.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class CredentialsConnector(BaseConnector):
    """Implement this if the connector needs to be able to read and write credentials
    on the fly. Typically used with shared credentials/tokens that might be renewed
    at any time."""

    @abc.abstractmethod
    def set_credentials_provider(
        self, credentials_provider: CredentialsProviderInterface
    ) -> None:
        """Hand the connector the provider it should obtain credentials from.

        Abstract: concrete connectors must override this method.

        Args:
            credentials_provider: Provider implementing
                ``CredentialsProviderInterface``.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -28,7 +28,8 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
|
||||
class Section(BaseModel):
|
||||
text: str
|
||||
link: str | None
|
||||
link: str | None = None
|
||||
image_file_name: str | None = None
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
|
||||
@@ -19,7 +19,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -671,12 +671,12 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"Please try again later."
|
||||
)
|
||||
else:
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
|
||||
) from http_err
|
||||
|
||||
except Exception as exc:
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Notion settings validation: {exc}"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
@@ -702,7 +702,9 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({error_msg})."
|
||||
)
|
||||
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
|
||||
raise UnexpectedValidationError(
|
||||
f"Slack API returned a failure: {error_msg}"
|
||||
)
|
||||
|
||||
# 3) If channels are specified, verify each is accessible
|
||||
if self.channels:
|
||||
@@ -740,13 +742,13 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({slack_error})."
|
||||
)
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Slack error '{slack_error}' during settings validation."
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise UnexpectedError(
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ def make_slack_api_rate_limited(
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
last_exception = None
|
||||
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -302,7 +302,7 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
raise InsufficientPermissionsError(
|
||||
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
|
||||
)
|
||||
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
|
||||
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
45
backend/onyx/connectors/vision_enabled_connector.py
Normal file
45
backend/onyx/connectors/vision_enabled_connector.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Mixin for connectors that need vision capabilities.
|
||||
"""
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class VisionEnabledConnector:
    """
    Mixin that equips a connector with an optional vision-capable LLM.

    Connectors that want to analyze images while indexing inherit from this
    class and call ``initialize_vision_llm()`` once during construction.

    Usage:
        class MyConnector(LoadConnector, VisionEnabledConnector):
            def __init__(self, ...):
                super().__init__(...)
                self.initialize_vision_llm()
    """

    def initialize_vision_llm(self) -> None:
        """
        Populate ``self.image_analysis_llm``.

        The attribute is always reset to None first. It only receives an LLM
        when image extraction/analysis is enabled by configuration and a
        vision-capable default LLM could be obtained; lookup failures are
        logged as warnings and leave the attribute as None.
        """
        self.image_analysis_llm: LLM | None = None

        # Feature switched off: nothing to look up.
        if not get_image_extraction_and_analysis_enabled():
            return

        try:
            vision_llm = get_default_llm_with_vision()
            self.image_analysis_llm = vision_llm
            if vision_llm is None:
                logger.warning(
                    "No LLM with vision found; image summarization will be disabled"
                )
        except Exception as err:
            # Best effort: a broken LLM setup must not prevent the connector
            # from being constructed.
            logger.warning(
                f"Failed to initialize vision LLM due to an error: {str(err)}. "
                "Image summarization will be disabled."
            )
            self.image_analysis_llm = None
|
||||
@@ -28,7 +28,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
@@ -42,6 +42,10 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
# Threshold for determining when to replace vs append iframe content
|
||||
IFRAME_TEXT_LENGTH_THRESHOLD = 700
|
||||
# Message indicating JavaScript is disabled, which often appears when scraping fails
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
|
||||
|
||||
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
|
||||
@@ -138,7 +142,8 @@ def get_internal_links(
|
||||
# Account for malformed backslashes in URLs
|
||||
href = href.replace("\\", "/")
|
||||
|
||||
if should_ignore_pound and "#" in href:
|
||||
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
|
||||
if should_ignore_pound and "#" in href and "#!" not in href:
|
||||
href = href.split("#")[0]
|
||||
|
||||
if not is_valid_url(href):
|
||||
@@ -152,6 +157,7 @@ def get_internal_links(
|
||||
|
||||
def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
playwright = sync_playwright().start()
|
||||
|
||||
browser = playwright.chromium.launch(headless=True)
|
||||
|
||||
context = browser.new_context()
|
||||
@@ -288,6 +294,7 @@ class WebConnector(LoadConnector):
|
||||
and converts them into documents"""
|
||||
visited_links: set[str] = set()
|
||||
to_visit: list[str] = self.to_visit_list
|
||||
content_hashes = set()
|
||||
|
||||
if not to_visit:
|
||||
raise ValueError("No URLs to visit")
|
||||
@@ -302,40 +309,41 @@ class WebConnector(LoadConnector):
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
while to_visit:
|
||||
current_url = to_visit.pop()
|
||||
if current_url in visited_links:
|
||||
initial_url = to_visit.pop()
|
||||
if initial_url in visited_links:
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
try:
|
||||
protected_url_check(current_url)
|
||||
protected_url_check(initial_url)
|
||||
except Exception as e:
|
||||
last_error = f"Invalid URL {current_url} due to {e}"
|
||||
last_error = f"Invalid URL {initial_url} due to {e}"
|
||||
logger.warning(last_error)
|
||||
continue
|
||||
|
||||
logger.info(f"Visiting {current_url}")
|
||||
index = len(visited_links)
|
||||
logger.info(f"{index}: Visiting {initial_url}")
|
||||
|
||||
try:
|
||||
check_internet_connection(current_url)
|
||||
check_internet_connection(initial_url)
|
||||
if restart_playwright:
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
|
||||
if current_url.split(".")[-1] == "pdf":
|
||||
if initial_url.split(".")[-1] == "pdf":
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(current_url)
|
||||
page_text, metadata = read_pdf_file(
|
||||
response = requests.get(initial_url)
|
||||
page_text, metadata, images = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
id=initial_url,
|
||||
sections=[Section(link=initial_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
semantic_identifier=initial_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@@ -347,21 +355,29 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
|
||||
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
|
||||
page_response = page.goto(
|
||||
initial_url,
|
||||
timeout=30000, # 30 seconds
|
||||
)
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
protected_url_check(final_page)
|
||||
current_url = final_page
|
||||
if current_url in visited_links:
|
||||
logger.info("Redirected page already indexed")
|
||||
final_url = page.url
|
||||
if final_url != initial_url:
|
||||
protected_url_check(final_url)
|
||||
initial_url = final_url
|
||||
if initial_url in visited_links:
|
||||
logger.info(
|
||||
f"{index}: {initial_url} redirected to {final_url} - already indexed"
|
||||
)
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
logger.info(f"{index}: {initial_url} redirected to {final_url}")
|
||||
visited_links.add(initial_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
@@ -379,26 +395,58 @@ class WebConnector(LoadConnector):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
if self.recursive:
|
||||
internal_links = get_internal_links(base_url, current_url, soup)
|
||||
internal_links = get_internal_links(base_url, initial_url, soup)
|
||||
for link in internal_links:
|
||||
if link not in visited_links:
|
||||
to_visit.append(link)
|
||||
|
||||
if page_response and str(page_response.status)[0] in ("4", "5"):
|
||||
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
|
||||
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
|
||||
logger.info(last_error)
|
||||
continue
|
||||
|
||||
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
|
||||
|
||||
"""For websites containing iframes that need to be scraped,
|
||||
the code below can extract text from within these iframes.
|
||||
"""
|
||||
logger.debug(
|
||||
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
|
||||
)
|
||||
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
|
||||
iframe_count = page.frame_locator("iframe").locator("html").count()
|
||||
if iframe_count > 0:
|
||||
iframe_texts = (
|
||||
page.frame_locator("iframe")
|
||||
.locator("html")
|
||||
.all_inner_texts()
|
||||
)
|
||||
document_text = "\n".join(iframe_texts)
|
||||
""" 700 is the threshold value for the length of the text extracted
|
||||
from the iframe based on the issue faced """
|
||||
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
|
||||
parsed_html.cleaned_text = document_text
|
||||
else:
|
||||
parsed_html.cleaned_text += "\n" + document_text
|
||||
|
||||
# Sometimes pages with #! will serve duplicate content
|
||||
# There are also just other ways this can happen
|
||||
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
|
||||
if hashed_text in content_hashes:
|
||||
logger.info(
|
||||
f"{index}: Skipping duplicate title + content for {initial_url}"
|
||||
)
|
||||
continue
|
||||
content_hashes.add(hashed_text)
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
id=initial_url,
|
||||
sections=[
|
||||
Section(link=current_url, text=parsed_html.cleaned_text)
|
||||
Section(link=initial_url, text=parsed_html.cleaned_text)
|
||||
],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
semantic_identifier=parsed_html.title or initial_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@@ -410,7 +458,7 @@ class WebConnector(LoadConnector):
|
||||
|
||||
page.close()
|
||||
except Exception as e:
|
||||
last_error = f"Failed to fetch '{current_url}': {e}"
|
||||
last_error = f"Failed to fetch '{initial_url}': {e}"
|
||||
logger.exception(last_error)
|
||||
playwright.stop()
|
||||
restart_playwright = True
|
||||
@@ -481,7 +529,9 @@ class WebConnector(LoadConnector):
|
||||
)
|
||||
else:
|
||||
# Could be a 5xx or another error, treat as unexpected
|
||||
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error validating '{test_url}': {e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -76,6 +76,10 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
provider_type=search_settings.provider_type,
|
||||
index_name=search_settings.index_name,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
# Whether switching to this model requires re-indexing
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
import numpy
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -18,11 +23,15 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import RerankingModel
|
||||
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
@@ -30,6 +39,124 @@ from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
def update_image_sections_with_query(
    sections: list[InferenceSection],
    query: str,
    llm: LLM,
) -> None:
    """
    Rewrite the content of image-backed chunks so it addresses the user's query.

    Every chunk in ``sections`` with a truthy ``image_file_name`` has its image
    read from the default file store and sent to ``llm`` together with
    ``query``; the chunk's ``content`` is then replaced in place with the
    answer. Chunks are processed in parallel via ``run_functions_in_parallel``.

    Args:
        sections: Sections whose chunks are inspected and mutated in place.
        query: The user's question, embedded verbatim in the prompt.
        llm: Model invoked once per image chunk.

    Returns:
        None. An affected chunk's content becomes the stripped LLM answer,
        "No relevant info found." when that answer is empty, or
        "Error analyzing image." when loading or analysis raised.
    """
    logger = setup_logger()
    logger.debug(f"Starting image section update with query: {query}")

    chunks_with_images = [
        chunk
        for section in sections
        for chunk in section.chunks
        if chunk.image_file_name
    ]

    if not chunks_with_images:
        logger.debug("No images to process in the sections")
        return  # No images to process

    logger.info(f"Found {len(chunks_with_images)} chunks with images to process")

    def process_image_chunk(chunk: InferenceChunk) -> tuple[str, str]:
        """Return ``(chunk.unique_id, new_content)`` for one image chunk."""
        try:
            logger.debug(
                f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_name}"
            )
            # Keep the DB session open only while the image bytes are read so
            # that it is released before the (slow) LLM round-trip.
            with get_session_with_current_tenant() as db_session:
                file_record = get_default_file_store(db_session).read_file(
                    cast(str, chunk.image_file_name), mode="b"
                )
                if not file_record:
                    logger.error(f"Image file not found: {chunk.image_file_name}")
                    raise Exception("File not found")
                file_content = file_record.read()

            image_base64 = base64.b64encode(file_content).decode()
            logger.debug(
                f"Successfully loaded image data for {chunk.image_file_name}"
            )

            # NOTE(review): the data URL always declares image/jpeg, whatever
            # format the stored file actually has - confirm this is acceptable
            # for the configured LLM provider.
            messages: list[BaseMessage] = [
                SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),
                HumanMessage(
                    content=[
                        {
                            "type": "text",
                            "text": (
                                f"The user's question is: '{query}'. "
                                "Please analyze the following image in that context:\n"
                            ),
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_base64}",
                            },
                        },
                    ]
                ),
            ]

            raw_response = llm.invoke(messages)

            answer_text = message_to_string(raw_response).strip()
            return (
                chunk.unique_id,
                answer_text if answer_text else "No relevant info found.",
            )

        except Exception:
            # Best effort: one failing image must not break the whole search;
            # the chunk receives a placeholder instead.
            logger.exception(
                f"Error updating image section with query source image url: {chunk.image_file_name}"
            )
            return chunk.unique_id, "Error analyzing image."

    image_processing_tasks = [
        FunctionCall(process_image_chunk, (chunk,)) for chunk in chunks_with_images
    ]

    logger.info(
        f"Starting parallel processing of {len(image_processing_tasks)} image tasks"
    )
    image_processing_results = run_functions_in_parallel(image_processing_tasks)
    logger.info(
        f"Completed parallel processing with {len(image_processing_results)} results"
    )

    # Create a mapping of chunk IDs to their processed content
    chunk_id_to_content = {}
    # Counts every task that returned a result; this includes the
    # "Error analyzing image." placeholder produced by process_image_chunk.
    success_count = 0
    for task_id, result in image_processing_results.items():
        if result:
            chunk_id, content = result
            chunk_id_to_content[chunk_id] = content
            success_count += 1
        else:
            logger.error(f"Task {task_id} failed to return a valid result")

    logger.info(
        f"Successfully processed {success_count}/{len(image_processing_results)} images"
    )

    # Update the chunks with the processed content
    updated_count = 0
    for section in sections:
        for chunk in section.chunks:
            if chunk.unique_id in chunk_id_to_content:
                chunk.content = chunk_id_to_content[chunk.unique_id]
                updated_count += 1

    logger.info(
        f"Updated content for {updated_count} chunks with image analysis results"
    )
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -286,6 +413,10 @@ def search_postprocessing(
|
||||
# NOTE: if we don't rerank, we can return the chunks immediately
|
||||
# since we know this is the final order.
|
||||
# This way the user experience isn't delayed by the LLM step
|
||||
if get_search_time_image_analysis_enabled():
|
||||
update_image_sections_with_query(
|
||||
retrieved_sections, search_query.query, llm
|
||||
)
|
||||
_log_top_section_links(search_query.search_type.value, retrieved_sections)
|
||||
yield retrieved_sections
|
||||
sections_yielded = True
|
||||
@@ -323,6 +454,13 @@ def search_postprocessing(
|
||||
)
|
||||
else:
|
||||
_log_top_section_links(search_query.search_type.value, reranked_sections)
|
||||
|
||||
# Add the image processing step here
|
||||
if get_search_time_image_analysis_enabled():
|
||||
update_image_sections_with_query(
|
||||
reranked_sections, search_query.query, llm
|
||||
)
|
||||
|
||||
yield reranked_sections
|
||||
|
||||
llm_selected_section_ids = (
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -11,6 +12,7 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Row
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
@@ -168,7 +170,7 @@ def get_chat_sessions_by_user(
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
stmt = stmt.order_by(desc(ChatSession.time_updated))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
@@ -375,24 +377,33 @@ def delete_chat_session(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
|
||||
def get_chat_sessions_older_than(
|
||||
days_old: int, db_session: Session
|
||||
) -> list[tuple[UUID | None, UUID]]:
|
||||
"""
|
||||
Retrieves chat sessions older than a specified number of days.
|
||||
|
||||
Args:
|
||||
days_old: The number of days to consider as "old".
|
||||
db_session: The database session.
|
||||
|
||||
Returns:
|
||||
A list of tuples, where each tuple contains the user_id (can be None) and the chat_session_id of an old chat session.
|
||||
"""
|
||||
|
||||
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
|
||||
old_sessions = db_session.execute(
|
||||
old_sessions: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute(
|
||||
select(ChatSession.user_id, ChatSession.id).where(
|
||||
ChatSession.time_created < cutoff_time
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_id, session_id in old_sessions:
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id, session_id, db_session, include_deleted=True, hard_delete=True
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"delete_chat_session exceptioned. "
|
||||
f"user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
# convert old_sessions to a conventional list of tuples
|
||||
returned_sessions: list[tuple[UUID | None, UUID]] = [
|
||||
(user_id, session_id) for user_id, session_id in old_sessions
|
||||
]
|
||||
|
||||
return returned_sessions
|
||||
|
||||
|
||||
def get_chat_message(
|
||||
@@ -962,6 +973,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Optional
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import literal
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import union_all
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -26,127 +25,87 @@ def search_chat_sessions(
|
||||
include_onyxbot_flows: bool = False,
|
||||
) -> Tuple[List[ChatSession], bool]:
|
||||
"""
|
||||
Search for chat sessions based on the provided query.
|
||||
If no query is provided, returns recent chat sessions.
|
||||
Fast full-text search on ChatSession + ChatMessage using tsvectors.
|
||||
|
||||
Returns a tuple of (chat_sessions, has_more)
|
||||
If no query is provided, returns the most recent chat sessions.
|
||||
Otherwise, searches both chat messages and session descriptions.
|
||||
|
||||
Returns a tuple of (sessions, has_more) where has_more indicates if
|
||||
there are additional results beyond the requested page.
|
||||
"""
|
||||
offset = (page - 1) * page_size
|
||||
offset_val = (page - 1) * page_size
|
||||
|
||||
# If no search query, we use standard SQLAlchemy pagination
|
||||
# If no query, just return the most recent sessions
|
||||
if not query or not query.strip():
|
||||
stmt = select(ChatSession)
|
||||
if user_id:
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(offset).limit(page_size + 1)
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
sessions = result.scalars().all()
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
has_more = len(sessions) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
sessions = sessions[:page_size]
|
||||
|
||||
return list(chat_sessions), has_more
|
||||
return list(sessions), has_more
|
||||
|
||||
words = query.lower().strip().split()
|
||||
# Otherwise, proceed with full-text search
|
||||
query = query.strip()
|
||||
|
||||
# Message mach subquery
|
||||
message_matches = []
|
||||
for word in words:
|
||||
word_like = f"%{word}%"
|
||||
message_match: Select = (
|
||||
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.where(func.lower(ChatMessage.message).like(word_like))
|
||||
)
|
||||
|
||||
if user_id:
|
||||
message_match = message_match.where(ChatSession.user_id == user_id)
|
||||
|
||||
message_matches.append(message_match)
|
||||
|
||||
if message_matches:
|
||||
message_matches_query = union_all(*message_matches).alias("message_matches")
|
||||
else:
|
||||
return [], False
|
||||
|
||||
# Description matches
|
||||
description_match: Select = select(
|
||||
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
|
||||
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
|
||||
|
||||
if user_id:
|
||||
description_match = description_match.where(ChatSession.user_id == user_id)
|
||||
base_conditions = []
|
||||
if user_id is not None:
|
||||
base_conditions.append(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
|
||||
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
description_match = description_match.where(ChatSession.deleted.is_(False))
|
||||
base_conditions.append(ChatSession.deleted.is_(False))
|
||||
|
||||
# Combine all match sources
|
||||
combined_matches = union_all(
|
||||
message_matches_query.select(), description_match
|
||||
).alias("combined_matches")
|
||||
message_tsv: ColumnClause = column("message_tsv")
|
||||
description_tsv: ColumnClause = column("description_tsv")
|
||||
|
||||
# Use CTE to group and get max rank
|
||||
session_ranks = (
|
||||
select(
|
||||
combined_matches.c.chat_session_id,
|
||||
func.max(combined_matches.c.search_rank).label("rank"),
|
||||
)
|
||||
.group_by(combined_matches.c.chat_session_id)
|
||||
.alias("session_ranks")
|
||||
ts_query = func.plainto_tsquery("english", query)
|
||||
|
||||
description_session_ids = (
|
||||
select(ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(description_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
# Get ranked sessions with pagination
|
||||
ranked_query = (
|
||||
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
|
||||
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
|
||||
.offset(offset)
|
||||
message_session_ids = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(message_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
combined_ids = description_session_ids.union(message_session_ids).alias(
|
||||
"combined_ids"
|
||||
)
|
||||
|
||||
final_stmt = (
|
||||
select(ChatSession)
|
||||
.join(combined_ids, ChatSession.id == combined_ids.c.id)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.distinct()
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
.options(joinedload(ChatSession.persona))
|
||||
)
|
||||
|
||||
result = ranked_query.all()
|
||||
session_objs = db_session.execute(final_stmt).scalars().all()
|
||||
|
||||
# Extract session IDs and ranks
|
||||
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
|
||||
session_ids = list(session_ids_with_ranks.keys())
|
||||
|
||||
if not session_ids:
|
||||
return [], False
|
||||
|
||||
# Now, let's query the actual ChatSession objects using the IDs
|
||||
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
# Full objects with eager loading
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
# Sort based on above ranking
|
||||
chat_sessions = sorted(
|
||||
chat_sessions,
|
||||
key=lambda session: (
|
||||
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
|
||||
session.time_created.timestamp() * -1, # Then by time (newest first)
|
||||
),
|
||||
)
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
has_more = len(session_objs) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
session_objs = session_objs[:page_size]
|
||||
|
||||
return chat_sessions, has_more
|
||||
return list(session_objs), has_more
|
||||
|
||||
@@ -360,18 +360,13 @@ def backend_update_credential_json(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential(
|
||||
def _delete_credential_internal(
|
||||
credential: Credential,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
"""Internal utility function to handle the actual deletion of a credential"""
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
@@ -416,6 +411,35 @@ def delete_credential(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential_for_user(
    credential_id: int,
    user: User,
    db_session: Session,
    force: bool = False,
) -> None:
    """Delete a credential after looking it up in the scope of ``user``.

    The credential is fetched with ``fetch_credential_by_id_for_user`` and then
    handed to ``_delete_credential_internal`` together with ``force``.

    Raises:
        ValueError: if the user-scoped lookup returns no credential.
    """
    user_credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
    if user_credential is not None:
        _delete_credential_internal(user_credential, credential_id, db_session, force)
        return

    raise ValueError(
        f"Credential by provided id {credential_id} does not exist or does not belong to user"
    )
|
||||
|
||||
|
||||
def delete_credential(
    credential_id: int,
    db_session: Session,
    force: bool = False,
) -> None:
    """Delete a credential by id without any ownership check (admin function).

    The credential is fetched with ``fetch_credential_by_id`` and then handed to
    ``_delete_credential_internal`` together with ``force``.

    Raises:
        ValueError: if no credential with ``credential_id`` is found.
    """
    found_credential = fetch_credential_by_id(credential_id, db_session)
    if found_credential is not None:
        _delete_credential_internal(found_credential, credential_id, db_session, force)
        return

    raise ValueError(f"Credential by provided id {credential_id} does not exist")
|
||||
|
||||
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
|
||||
@@ -63,6 +63,9 @@ class IndexModelStatus(str, PyEnum):
|
||||
PRESENT = "PRESENT"
|
||||
FUTURE = "FUTURE"
|
||||
|
||||
def is_current(self) -> bool:
    """Return True when this status equals ``IndexModelStatus.PRESENT``."""
    present_status = IndexModelStatus.PRESENT
    return self == present_status
|
||||
|
||||
|
||||
class ChatSessionSharedStatus(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
@@ -83,3 +86,11 @@ class AccessType(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
SYNC = "sync"
|
||||
|
||||
|
||||
class EmbeddingPrecision(str, PyEnum):
    """Precision of stored embedding vectors.

    The member values match the Vespa tensor type names. Only ``float`` and
    ``bfloat16`` are supported for now, since there is not a good reason to
    specify anything else.
    """

    BFLOAT16 = "bfloat16"
    FLOAT = "float"
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import validates
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -25,6 +26,7 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
@@ -44,7 +46,13 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
|
||||
from onyx.db.enums import (
|
||||
AccessType,
|
||||
EmbeddingPrecision,
|
||||
IndexingMode,
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
@@ -205,6 +213,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
@validates("email")
def validate_email(self, key: str, value: str) -> str:
    """Lower-case ``value`` for the ``email`` attribute; falsy values are returned unchanged."""
    if not value:
        return value
    return value.lower()
|
||||
|
||||
@property
|
||||
def password_configured(self) -> bool:
|
||||
"""
|
||||
@@ -710,6 +722,23 @@ class SearchSettings(Base):
|
||||
ForeignKey("embedding_provider.provider_type"), nullable=True
|
||||
)
|
||||
|
||||
# Whether switching to this model should re-index all connectors in the background
|
||||
# if no re-index is needed, will be ignored. Only used during the switch-over process.
|
||||
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# allows for quantization -> less memory usage for a small performance hit
|
||||
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
|
||||
Enum(EmbeddingPrecision, native_enum=False)
|
||||
)
|
||||
|
||||
# can be used to reduce dimensionality of vectors and save memory with
|
||||
# a small performance hit. More details in the `Reducing embedding dimensions`
|
||||
# section here:
|
||||
# https://platform.openai.com/docs/guides/embeddings#embedding-models
|
||||
# If not specified, will just use the model_dim without any reduction.
|
||||
# NOTE: this is only currently available for OpenAI models
|
||||
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Mini and Large Chunks (large chunk also checks for model max context)
|
||||
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
@@ -791,6 +820,12 @@ class SearchSettings(Base):
|
||||
self.multipass_indexing, self.model_name, self.provider_type
|
||||
)
|
||||
|
||||
@property
def final_embedding_dim(self) -> int:
    """Embedding dimension to use for the index.

    Returns ``reduced_dimension`` when it holds a truthy value, otherwise
    falls back to ``model_dim``.
    """
    return self.reduced_dimension or self.model_dim
|
||||
|
||||
@staticmethod
|
||||
def can_use_large_chunks(
|
||||
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
|
||||
@@ -1755,6 +1790,7 @@ class ChannelConfig(TypedDict):
|
||||
channel_name: str | None # None for default channel config
|
||||
respond_tag_only: NotRequired[bool] # defaults to False
|
||||
respond_to_bots: NotRequired[bool] # defaults to False
|
||||
is_ephemeral: NotRequired[bool] # defaults to False
|
||||
respond_member_group_list: NotRequired[list[str]]
|
||||
answer_filters: NotRequired[list[AllowedAnswerFilters]]
|
||||
# If None then no follow up
|
||||
@@ -2269,6 +2305,10 @@ class UserTenantMapping(Base):
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
@validates("email")
def validate_email(self, key: str, value: str) -> str:
    """Lower-case ``value`` for the ``email`` attribute; falsy values are returned unchanged."""
    if not value:
        return value
    return value.lower()
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
|
||||
@@ -100,9 +100,14 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
@@ -204,13 +209,21 @@ def create_update_persona(
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
is_default_persona: bool | None = create_persona_request.is_default_persona
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
if user:
|
||||
# Curators can edit default personas, but not make them
|
||||
if (
|
||||
user.role == UserRole.CURATOR
|
||||
or user.role == UserRole.GLOBAL_CURATOR
|
||||
):
|
||||
is_default_persona = None
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
@@ -236,7 +249,7 @@ def create_update_persona(
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
is_default_persona=is_default_persona,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@@ -423,7 +436,7 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
@@ -518,7 +531,11 @@ def upsert_persona(
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = is_default_persona
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
)
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
if document_sets is not None:
|
||||
@@ -570,7 +587,9 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=is_default_persona,
|
||||
is_default_persona=is_default_persona
|
||||
if is_default_persona is not None
|
||||
else False,
|
||||
labels=labels or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
|
||||
@@ -148,3 +148,28 @@ def upsert_pgfilestore(
|
||||
db_session.commit()
|
||||
|
||||
return pgfilestore
|
||||
|
||||
|
||||
def save_bytes_to_pgfilestore(
    db_session: Session,
    raw_bytes: bytes,
    media_type: str,
    identifier: str,
    display_name: str,
    file_origin: FileOrigin = FileOrigin.OTHER,
) -> PGFileStore:
    """Store ``raw_bytes`` as a large object and upsert the matching PGFileStore record.

    The record's file name is ``<lower-cased file_origin name>_<identifier>`` and
    ``media_type`` is stored as its file type. The upsert is called with
    ``commit=True``.

    Returns:
        The record returned by ``upsert_pgfilestore``.
    """
    stored_name = f"{file_origin.name.lower()}_{identifier}"
    byte_stream = BytesIO(raw_bytes)
    large_object_oid = create_populate_lobj(byte_stream, db_session)
    return upsert_pgfilestore(
        file_name=stored_name,
        display_name=display_name,
        file_origin=file_origin,
        file_type=media_type,
        lobj_oid=large_object_oid,
        db_session=db_session,
        commit=True,
    )
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.llm import fetch_embedding_provider
|
||||
from onyx.db.models import CloudEmbeddingProvider
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -59,12 +60,15 @@ def create_search_settings(
|
||||
index_name=search_settings.index_name,
|
||||
provider_type=search_settings.provider_type,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
)
|
||||
|
||||
db_session.add(embedding_model)
|
||||
@@ -305,6 +309,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
model_dim=(
|
||||
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
),
|
||||
embedding_precision=(EmbeddingPrecision.FLOAT),
|
||||
normalize=(
|
||||
NORMALIZE_EMBEDDINGS
|
||||
if is_overridden
|
||||
@@ -322,6 +327,7 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
return IndexingSetting(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
embedding_precision=(EmbeddingPrecision.FLOAT),
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
|
||||
53
backend/onyx/db/seeding/chat_history_seeding.py
Normal file
53
backend/onyx/db/seeding/chat_history_seeding.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
    """Seed chat sessions and messages with randomized timestamps, for testing.

    num_sessions: the number of sessions to create
    num_messages: the number of messages to create per session
    days: how many days back from the current time the session creation times
          may be spread
    """

    def _up_to_ten_minutes() -> timedelta:
        # one random draw per call, same range as every minute offset below
        return timedelta(minutes=random.randint(0, 10))

    with get_session_with_current_tenant() as db_session:
        for session_index in range(num_sessions):
            create_chat_session(
                db_session, f"pytest_session_{session_index}", None, None
            )

        # randomize all session times (every ChatSession row, not just new ones)
        for row in db_session.query(ChatSession).all():
            days_back = timedelta(days=random.randint(0, days))
            row.time_created = datetime.utcnow() - days_back
            row.time_updated = row.time_created + _up_to_ten_minutes()

            root_message = get_or_create_root_message(row.id, db_session)

            for message_index in range(num_messages):
                chat_message = create_new_chat_message(
                    row.id,
                    root_message,
                    f"pytest_message_{message_index}",
                    None,
                    0,
                    MessageType.USER,
                    db_session,
                )

                # row attributes are read again on purpose rather than cached
                # in locals, since a commit happens on every iteration
                chat_message.time_sent = row.time_created + _up_to_ten_minutes()
                db_session.commit()

        db_session.commit()
|
||||
@@ -8,10 +8,12 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import (
|
||||
count_unique_cc_pairs_with_successful_index_attempts,
|
||||
)
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -19,7 +21,49 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
def _perform_index_swap(
    db_session: Session,
    current_search_settings: SearchSettings,
    secondary_search_settings: SearchSettings,
    all_cc_pairs: list[ConnectorCredentialPair],
) -> None:
    """Swap the indices and expire the old one.

    Args:
        db_session: session used for the status updates, the cancellation of
            indexing attempts and the cc-pair resyncs.
        current_search_settings: settings to be marked ``PAST``. The caller's
            value is used as-is; it is not re-queried here.
        secondary_search_settings: settings to be marked ``PRESENT``; they also
            define the primary index passed to ``ensure_indices_exist``.
        all_cc_pairs: connector-credential pairs to resync after the swap.
    """
    update_search_settings_status(
        search_settings=current_search_settings,
        new_status=IndexModelStatus.PAST,
        db_session=db_session,
    )

    update_search_settings_status(
        search_settings=secondary_search_settings,
        new_status=IndexModelStatus.PRESENT,
        db_session=db_session,
    )

    if all_cc_pairs:
        kv_store = get_kv_store()
        kv_store.store(KV_REINDEX_KEY, False)

        # Expire jobs for the now past index/embedding model
        cancel_indexing_attempts_past_model(db_session)

        # Recount aggregates
        for cc_pair in all_cc_pairs:
            resync_cc_pair(cc_pair, db_session=db_session)

    # remove the old index from the vector db
    document_index = get_default_document_index(secondary_search_settings, None)
    document_index.ensure_indices_exist(
        primary_embedding_dim=secondary_search_settings.final_embedding_dim,
        primary_embedding_precision=secondary_search_settings.embedding_precision,
        # just finished swap, no more secondary index
        secondary_index_embedding_dim=None,
        secondary_index_embedding_precision=None,
    )
|
||||
|
||||
|
||||
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
"""Get count of cc-pairs and count of successful index_attempts for the
|
||||
new model grouped by connector + credential, if it's the same, then assume
|
||||
new index is done building. If so, swap the indices and expire the old one.
|
||||
@@ -27,52 +71,45 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
|
||||
old_search_settings = None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if not search_settings:
|
||||
if not secondary_search_settings:
|
||||
return None
|
||||
|
||||
# If the secondary search settings are not configured to reindex in the background,
|
||||
# we can just swap over instantly
|
||||
if not secondary_search_settings.background_reindex_enabled:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
_perform_index_swap(
|
||||
db_session=db_session,
|
||||
current_search_settings=current_search_settings,
|
||||
secondary_search_settings=secondary_search_settings,
|
||||
all_cc_pairs=all_cc_pairs,
|
||||
)
|
||||
return current_search_settings
|
||||
|
||||
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
|
||||
# function is correct. The unique_cc_indexings are specifically for the existing cc-pairs
|
||||
old_search_settings = None
|
||||
if unique_cc_indexings > cc_pair_count:
|
||||
logger.error("More unique indexings than cc pairs, should not occur")
|
||||
|
||||
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
_perform_index_swap(
|
||||
db_session=db_session,
|
||||
current_search_settings=current_search_settings,
|
||||
secondary_search_settings=secondary_search_settings,
|
||||
all_cc_pairs=all_cc_pairs,
|
||||
)
|
||||
|
||||
update_search_settings_status(
|
||||
search_settings=search_settings,
|
||||
new_status=IndexModelStatus.PRESENT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if cc_pair_count > 0:
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
old_search_settings = current_search_settings
|
||||
old_search_settings = current_search_settings
|
||||
|
||||
return old_search_settings
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@@ -145,17 +146,21 @@ class Verifiable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the document index exists and is consistent with the expectations in the code.
|
||||
|
||||
Parameters:
|
||||
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- primary_embedding_precision: Precision of the vector similarity part of the search
|
||||
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
|
||||
behind the scenes. The secondary index should only be built when switching
|
||||
embedding models therefore this dim should be different from the primary index.
|
||||
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -164,6 +169,7 @@ class Verifiable(abc.ABC):
|
||||
def register_multitenant_indices(
|
||||
indices: list[str],
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
"""
|
||||
Register multitenant indices with the document index.
|
||||
|
||||
@@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
summary: dynamic
|
||||
}
|
||||
# Title embedding (x1)
|
||||
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
|
||||
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
|
||||
indexing: attribute | index
|
||||
attribute {
|
||||
distance-metric: angular
|
||||
@@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
}
|
||||
# Content embeddings (chunk + optional mini chunks embeddings)
|
||||
# "t" and "x" are arbitrary names, not special keywords
|
||||
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
|
||||
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
|
||||
indexing: attribute | index
|
||||
attribute {
|
||||
distance-metric: angular
|
||||
@@ -55,6 +55,9 @@ schema DANSWER_CHUNK_NAME {
|
||||
field blurb type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field image_file_name type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
# https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it
|
||||
field source_type type string {
|
||||
indexing: summary | attribute
|
||||
|
||||
@@ -5,4 +5,7 @@
|
||||
<allow
|
||||
until="DATE_REPLACEMENT"
|
||||
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
|
||||
<allow
|
||||
until='DATE_REPLACEMENT'
|
||||
comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
|
||||
</validation-overrides>
|
||||
|
||||
@@ -31,6 +31,7 @@ from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
|
||||
from onyx.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
|
||||
from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS
|
||||
@@ -130,6 +131,7 @@ def _vespa_hit_to_inference_chunk(
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
source_type=fields[SOURCE_TYPE],
|
||||
image_file_name=fields.get(IMAGE_FILE_NAME),
|
||||
title=fields.get(TITLE),
|
||||
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
|
||||
boost=fields.get(BOOST, 1),
|
||||
@@ -211,6 +213,7 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
# Check if the response contains any documents
|
||||
response_data = response.json()
|
||||
|
||||
if "documents" in response_data:
|
||||
for document in response_data["documents"]:
|
||||
if filters.access_control_list:
|
||||
@@ -310,6 +313,11 @@ def query_vespa(
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
+ (
|
||||
f"\nResponse: {e.response.text}"
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
else ""
|
||||
)
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
@@ -63,6 +64,7 @@ from onyx.document_index.vespa_constants import DATE_REPLACEMENT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
|
||||
@@ -112,6 +114,21 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
|
||||
return "\n".join(doc_lines)
|
||||
|
||||
|
||||
def _replace_template_values_in_schema(
    schema_template: str,
    index_name: str,
    embedding_dim: int,
    embedding_precision: EmbeddingPrecision,
) -> str:
    """Fill in the precision, index-name and dimension placeholders of a schema template.

    Replacements are applied in that order: the precision placeholder receives
    ``embedding_precision.value``, the chunk-name placeholder receives
    ``index_name`` and the dimension placeholder receives ``str(embedding_dim)``.
    """
    substitutions = (
        (EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value),
        (DANSWER_CHUNK_REPLACEMENT_PAT, index_name),
        (VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim)),
    )

    schema = schema_template
    for placeholder, replacement in substitutions:
        schema = schema.replace(placeholder, replacement)
    return schema
|
||||
|
||||
|
||||
def add_ngrams_to_schema(schema_content: str) -> str:
|
||||
# Add the match blocks containing gram and gram-size to title and content fields
|
||||
schema_content = re.sub(
|
||||
@@ -163,8 +180,10 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
if MULTI_TENANT:
|
||||
logger.info(
|
||||
@@ -221,18 +240,29 @@ class VespaIndex(DocumentIndex):
|
||||
schema_template = schema_f.read()
|
||||
schema_template = schema_template.replace(TENANT_ID_PAT, "")
|
||||
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
|
||||
schema = _replace_template_values_in_schema(
|
||||
schema_template,
|
||||
self.index_name,
|
||||
primary_embedding_dim,
|
||||
primary_embedding_precision,
|
||||
)
|
||||
|
||||
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
|
||||
schema = schema.replace(TENANT_ID_PAT, "")
|
||||
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
|
||||
|
||||
if self.secondary_index_name:
|
||||
upcoming_schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
|
||||
if secondary_index_embedding_dim is None:
|
||||
raise ValueError("Secondary index embedding dimension is required")
|
||||
if secondary_index_embedding_precision is None:
|
||||
raise ValueError("Secondary index embedding precision is required")
|
||||
|
||||
upcoming_schema = _replace_template_values_in_schema(
|
||||
schema_template,
|
||||
self.secondary_index_name,
|
||||
secondary_index_embedding_dim,
|
||||
secondary_index_embedding_precision,
|
||||
)
|
||||
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
|
||||
|
||||
zip_file = in_memory_zip_from_file_bytes(zip_dict)
|
||||
@@ -251,6 +281,7 @@ class VespaIndex(DocumentIndex):
|
||||
def register_multitenant_indices(
|
||||
indices: list[str],
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError("Multi-tenant is not enabled")
|
||||
@@ -309,13 +340,14 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
for i, index_name in enumerate(indices):
|
||||
embedding_dim = embedding_dims[i]
|
||||
embedding_precision = embedding_precisions[i]
|
||||
logger.info(
|
||||
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
|
||||
)
|
||||
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
|
||||
schema = _replace_template_values_in_schema(
|
||||
schema_template, index_name, embedding_dim, embedding_precision
|
||||
)
|
||||
schema = schema.replace(
|
||||
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import EMBEDDINGS
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
|
||||
from onyx.document_index.vespa_constants import METADATA
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
@@ -198,13 +199,13 @@ def _index_vespa_chunk(
|
||||
# which only calls VespaIndex.update
|
||||
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
IMAGE_FILE_NAME: chunk.image_file_name,
|
||||
BOOST: chunk.boost,
|
||||
}
|
||||
|
||||
if multitenant:
|
||||
if chunk.tenant_id:
|
||||
vespa_document_fields[TENANT_ID] = chunk.tenant_id
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
|
||||
logger.debug(f'Indexing to URL "{vespa_url}"')
|
||||
res = http_client.post(
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.configs.app_configs import VESPA_TENANT_PORT
|
||||
from onyx.configs.constants import SOURCE_TYPE
|
||||
|
||||
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
|
||||
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
|
||||
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
|
||||
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
|
||||
@@ -76,6 +77,7 @@ PRIMARY_OWNERS = "primary_owners"
|
||||
SECONDARY_OWNERS = "secondary_owners"
|
||||
RECENCY_BIAS = "recency_bias"
|
||||
HIDDEN = "hidden"
|
||||
IMAGE_FILE_NAME = "image_file_name"
|
||||
|
||||
# Specific to Vespa, needed for highlighting matching keywords / section
|
||||
CONTENT_SUMMARY = "content_summary"
|
||||
@@ -93,6 +95,7 @@ YQL_BASE = (
|
||||
f"{SEMANTIC_IDENTIFIER}, "
|
||||
f"{TITLE}, "
|
||||
f"{SECTION_CONTINUATION}, "
|
||||
f"{IMAGE_FILE_NAME}, "
|
||||
f"{BOOST}, "
|
||||
f"{HIDDEN}, "
|
||||
f"{DOC_UPDATED_AT}, "
|
||||
|
||||
@@ -9,15 +9,17 @@ from email.parser import Parser as EmailParser
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import IO
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import chardet
|
||||
import docx # type: ignore
|
||||
import openpyxl # type: ignore
|
||||
import pptx # type: ignore
|
||||
from docx import Document
|
||||
from docx import Document as DocxDocument
|
||||
from fastapi import UploadFile
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
|
||||
@@ -31,10 +33,8 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TEXT_SECTION_SEPARATOR = "\n\n"
|
||||
|
||||
|
||||
PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".txt",
|
||||
".md",
|
||||
@@ -49,7 +49,6 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
|
||||
".yaml",
|
||||
]
|
||||
|
||||
|
||||
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
|
||||
".pdf",
|
||||
".docx",
|
||||
@@ -58,6 +57,16 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
|
||||
".eml",
|
||||
".epub",
|
||||
".html",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".webp",
|
||||
]
|
||||
|
||||
IMAGE_MEDIA_TYPES = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
|
||||
@@ -67,11 +76,13 @@ def is_text_file_extension(file_name: str) -> bool:
|
||||
|
||||
def get_file_ext(file_path_or_name: str | Path) -> str:
|
||||
_, extension = os.path.splitext(file_path_or_name)
|
||||
# standardize all extensions to be lowercase so that checks against
|
||||
# VALID_FILE_EXTENSIONS and similar will work as intended
|
||||
return extension.lower()
|
||||
|
||||
|
||||
def is_valid_media_type(media_type: str) -> bool:
    """Return True when `media_type` is listed in IMAGE_MEDIA_TYPES."""
    is_listed = media_type in IMAGE_MEDIA_TYPES
    return is_listed
|
||||
|
||||
|
||||
def is_valid_file_ext(ext: str) -> bool:
    """Return True when `ext` is listed in VALID_FILE_EXTENSIONS.

    The comparison is exact, so `ext` should be lowercased and include the
    leading dot (the form get_file_ext returns).
    """
    is_listed = ext in VALID_FILE_EXTENSIONS
    return is_listed
|
||||
|
||||
@@ -79,17 +90,18 @@ def is_valid_file_ext(ext: str) -> bool:
|
||||
def is_text_file(file: IO[bytes]) -> bool:
    """Heuristically decide whether a binary stream holds plain text.

    Reads the first 1024 bytes, rewinds the stream to position 0, and reports
    True only if every byte read is a printable or whitespace character. An
    empty read therefore counts as text.
    """
    head = file.read(1024)
    file.seek(0)

    # Allowed bytes: BEL, BS, TAB, LF, FF, CR, ESC, plus 0x20-0xFF except DEL.
    allowed_controls = {7, 8, 9, 10, 12, 13, 27}
    printable = set(range(0x20, 0x100)) - {0x7F}
    allowed = allowed_controls | printable

    for byte in head:
        if byte not in allowed:
            return False
    return True
|
||||
|
||||
|
||||
def detect_encoding(file: IO[bytes]) -> str:
    """Guess the text encoding of a binary stream with chardet.

    Samples up to the first 50,000 bytes and rewinds the stream to position 0
    afterwards. Falls back to "utf-8" when chardet reports no encoding.
    """
    sample = file.read(50000)
    try:
        detected = chardet.detect(sample)["encoding"]
    finally:
        # Rewind even if detection raises so callers can still read the file.
        file.seek(0)
    return detected or "utf-8"
|
||||
|
||||
|
||||
@@ -99,14 +111,14 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
# To include additional metadata in the search index, add a .onyx_metadata.json file
|
||||
# to the zip file. This file should contain a list of objects with the following format:
|
||||
# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
|
||||
def load_files_from_zip(
|
||||
zip_file_io: IO,
|
||||
ignore_macos_resource_fork_files: bool = True,
|
||||
ignore_dirs: bool = True,
|
||||
) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]:
|
||||
"""
|
||||
If there's a .onyx_metadata.json in the zip, attach those metadata to each subfile.
|
||||
"""
|
||||
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
|
||||
zip_metadata = {}
|
||||
try:
|
||||
@@ -118,24 +130,31 @@ def load_files_from_zip(
|
||||
# convert list of dicts to dict of dicts
|
||||
zip_metadata = {d["filename"]: d for d in zip_metadata}
|
||||
except json.JSONDecodeError:
|
||||
logger.warn(f"Unable to load {DANSWER_METADATA_FILENAME}")
|
||||
logger.warning(f"Unable to load {DANSWER_METADATA_FILENAME}")
|
||||
except KeyError:
|
||||
logger.info(f"No {DANSWER_METADATA_FILENAME} file")
|
||||
|
||||
for file_info in zip_file.infolist():
|
||||
with zip_file.open(file_info.filename, "r") as file:
|
||||
if ignore_dirs and file_info.is_dir():
|
||||
continue
|
||||
if ignore_dirs and file_info.is_dir():
|
||||
continue
|
||||
|
||||
if (
|
||||
ignore_macos_resource_fork_files
|
||||
and is_macos_resource_fork_file(file_info.filename)
|
||||
) or file_info.filename == DANSWER_METADATA_FILENAME:
|
||||
continue
|
||||
yield file_info, file, zip_metadata.get(file_info.filename, {})
|
||||
if (
|
||||
ignore_macos_resource_fork_files
|
||||
and is_macos_resource_fork_file(file_info.filename)
|
||||
) or file_info.filename == DANSWER_METADATA_FILENAME:
|
||||
continue
|
||||
|
||||
with zip_file.open(file_info.filename, "r") as subfile:
|
||||
yield file_info, subfile, zip_metadata.get(file_info.filename, {})
|
||||
|
||||
|
||||
def _extract_onyx_metadata(line: str) -> dict | None:
|
||||
"""
|
||||
Example: first line has:
|
||||
<!-- DANSWER_METADATA={"title": "..."} -->
|
||||
or
|
||||
#DANSWER_METADATA={"title":"..."}
|
||||
"""
|
||||
html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
|
||||
hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
|
||||
|
||||
@@ -161,9 +180,13 @@ def read_text_file(
|
||||
errors: str = "replace",
|
||||
ignore_onyx_metadata: bool = True,
|
||||
) -> tuple[str, dict]:
|
||||
"""
|
||||
For plain text files. Optionally extracts Onyx metadata from the first line.
|
||||
"""
|
||||
metadata = {}
|
||||
file_content_raw = ""
|
||||
for ind, line in enumerate(file):
|
||||
# decode
|
||||
try:
|
||||
line = line.decode(encoding) if isinstance(line, bytes) else line
|
||||
except UnicodeDecodeError:
|
||||
@@ -173,131 +196,132 @@ def read_text_file(
|
||||
else line
|
||||
)
|
||||
|
||||
if ind == 0:
|
||||
metadata_or_none = (
|
||||
None if ignore_onyx_metadata else _extract_onyx_metadata(line)
|
||||
)
|
||||
if metadata_or_none is not None:
|
||||
metadata = metadata_or_none
|
||||
else:
|
||||
file_content_raw += line
|
||||
else:
|
||||
file_content_raw += line
|
||||
# optionally parse metadata in the first line
|
||||
if ind == 0 and not ignore_onyx_metadata:
|
||||
potential_meta = _extract_onyx_metadata(line)
|
||||
if potential_meta is not None:
|
||||
metadata = potential_meta
|
||||
continue
|
||||
|
||||
file_content_raw += line
|
||||
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
    """Extract only the text of a PDF.

    Delegates to read_pdf_file and discards everything it returns besides the
    text (metadata and any embedded images).
    """
    # Index instead of unpacking so this does not break on the number of
    # extra values read_pdf_file returns alongside the text.
    return read_pdf_file(file, pdf_pass)[0]
|
||||
|
||||
|
||||
def read_pdf_file(
    file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
) -> tuple[str, dict, list[tuple[bytes, str]]]:
    """Read a PDF and return its text, basic metadata, and embedded images.

    Args:
        file: Binary stream containing the PDF.
        pdf_pass: Password to try when the PDF is encrypted.
        extract_images: When True, also collect the images embedded in pages.

    Returns:
        (text, metadata, images), where images is a list of
        (image_bytes, image_name) tuples and is empty unless `extract_images`
        is True. If the PDF cannot be read or decrypted the text is "" and the
        image list is empty, so the file stays discoverable by title.
    """
    metadata: dict[str, Any] = {}
    extracted_images: list[tuple[bytes, str]] = []
    try:
        pdf_reader = PdfReader(file)

        if pdf_reader.is_encrypted:
            if pdf_pass is None:
                logger.warning(
                    "No Password for an encrypted PDF, returning empty text."
                )
                return "", metadata, []

            decrypt_success = False
            try:
                decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
            except Exception:
                logger.error("Unable to decrypt pdf")

            if not decrypt_success:
                return "", metadata, []

        # Basic PDF metadata; keys are stored without their leading "/".
        if pdf_reader.metadata is not None:
            for key, value in pdf_reader.metadata.items():
                clean_key = key.lstrip("/")
                if isinstance(value, str) and value.strip():
                    metadata[clean_key] = value
                elif isinstance(value, list) and all(
                    isinstance(item, str) for item in value
                ):
                    metadata[clean_key] = ", ".join(value)

        text = TEXT_SECTION_SEPARATOR.join(
            page.extract_text() for page in pdf_reader.pages
        )

        if extract_images:
            for page_num, page in enumerate(pdf_reader.pages):
                for image_file_object in page.images:
                    # Handle each image on its own: one undecodable image must
                    # not discard the text and images already extracted.
                    try:
                        image = Image.open(io.BytesIO(image_file_object.data))
                        image_format = image.format or "PNG"
                        img_byte_arr = io.BytesIO()
                        image.save(img_byte_arr, format=image_format)
                    except Exception:
                        logger.warning(
                            f"Skipping unreadable image on PDF page {page_num + 1}"
                        )
                        continue

                    image_name = (
                        f"page_{page_num + 1}_image_{image_file_object.name}."
                        f"{image_format.lower()}"
                    )
                    extracted_images.append((img_byte_arr.getvalue(), image_name))

        return text, metadata, extracted_images

    except PdfStreamError:
        logger.exception("Invalid PDF file")
    except Exception:
        logger.exception("Failed to read PDF")

    # The contents could not be parsed; return whatever metadata was gathered.
    return "", metadata, []
|
||||
|
||||
|
||||
def _is_simple_docx_table(table: docx.table.Table) -> bool:
    """Return True for a table with no omitted cells and no nested tables."""
    for row in table.rows:
        # No omitted cells
        if row.grid_cols_before > 0 or row.grid_cols_after > 0:
            return False

        # No nested tables
        if any(cell.tables for cell in row.cells):
            return False

    return True


def _docx_cell_text(cell: docx.table._Cell) -> str:
    """Join a cell's non-empty paragraphs with spaces; "N/A" if it has none."""
    cell_paragraphs = [para.text.strip() for para in cell.paragraphs]
    return " ".join(p for p in cell_paragraphs if p) or "N/A"


def docx_to_text_and_images(
    file: IO[Any],
) -> tuple[str, list[tuple[bytes, str]]]:
    """Extract the text and the embedded images of a .docx file.

    Paragraphs and simple tables are emitted in document order. Each table row
    becomes one line with its cells joined by ",\\t". Tables with omitted
    cells or nested tables are skipped.

    Returns:
        (text_content, images), where images is a list of
        (image_bytes, image_file_name) tuples taken from the document's
        image relationships.
    """
    paragraphs: list[str] = []
    embedded_images: list[tuple[bytes, str]] = []

    doc = docx.Document(file)

    for item in doc.iter_inner_content():
        if isinstance(item, docx.text.paragraph.Paragraph):
            paragraphs.append(item.text)

        elif isinstance(item, docx.table.Table):
            if not item.rows or not _is_simple_docx_table(item):
                continue

            # Every row is a new line, joined with a single newline
            table_content = "\n".join(
                ",\t".join(_docx_cell_text(cell) for cell in row.cells)
                for row in item.rows
            )
            paragraphs.append(table_content)

    for rel in doc.part.rels.values():
        if "image" not in rel.reltype:
            continue
        # Linked (external) images have no target part holding the bytes.
        if getattr(rel, "is_external", False):
            continue
        image_bytes = rel.target_part.blob
        image_name = os.path.basename(str(rel.target_part.partname))
        embedded_images.append((image_bytes, image_name))

    # Docx already has good spacing between paragraphs
    return "\n".join(paragraphs), embedded_images


def docx_to_text(file: IO[Any]) -> str:
    """Extract only the text of a .docx file (embedded images are ignored)."""
    return docx_to_text_and_images(file)[0]
|
||||
|
||||
|
||||
def pptx_to_text(file: IO[Any]) -> str:
    """Extract text from a PowerPoint file, one labelled section per slide.

    Each section starts with a "Slide N:" header followed by the text of every
    shape on that slide that exposes a `text` attribute, one per line.
    """
    deck = pptx.Presentation(file)
    slide_sections: list[str] = []
    for number, slide in enumerate(deck.slides, start=1):
        shape_lines = [
            shape.text + "\n" for shape in slide.shapes if hasattr(shape, "text")
        ]
        slide_sections.append(f"\nSlide {number}:\n" + "".join(shape_lines))
    return TEXT_SECTION_SEPARATOR.join(slide_sections)
|
||||
|
||||
|
||||
@@ -305,18 +329,21 @@ def xlsx_to_text(file: IO[Any]) -> str:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_string = "\n".join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True)
|
||||
)
|
||||
text_content.append(sheet_string)
|
||||
rows = []
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell) if cell is not None else "" for cell in row)
|
||||
rows.append(row_str)
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
text_file = io.TextIOWrapper(file, encoding=detect_encoding(file))
|
||||
encoding = detect_encoding(file)
|
||||
text_file = io.TextIOWrapper(file, encoding=encoding)
|
||||
parser = EmailParser()
|
||||
message = parser.parse(text_file)
|
||||
|
||||
text_content = []
|
||||
for part in message.walk():
|
||||
if part.get_content_type().startswith("text/plain"):
|
||||
@@ -342,8 +369,8 @@ def epub_to_text(file: IO[Any]) -> str:
|
||||
|
||||
def file_io_to_text(file: IO[Any]) -> str:
    """Decode a stream as plain text using its detected encoding.

    The metadata that read_text_file returns alongside the text is discarded.
    """
    detected = detect_encoding(file)
    return read_text_file(file, encoding=detected)[0]
|
||||
|
||||
|
||||
def extract_file_text(
|
||||
@@ -352,9 +379,13 @@ def extract_file_text(
|
||||
break_on_unprocessable: bool = True,
|
||||
extension: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Legacy function that returns *only text*, ignoring embedded images.
|
||||
For backward-compatibility in code that only wants text.
|
||||
"""
|
||||
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
|
||||
".pdf": pdf_to_text,
|
||||
".docx": docx_to_text,
|
||||
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
|
||||
".pptx": pptx_to_text,
|
||||
".xlsx": xlsx_to_text,
|
||||
".eml": eml_to_text,
|
||||
@@ -368,24 +399,23 @@ def extract_file_text(
|
||||
return unstructured_to_text(file, file_name)
|
||||
except Exception as unstructured_error:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. "
|
||||
"Falling back to normal processing."
|
||||
)
|
||||
# Fall through to normal processing
|
||||
final_extension: str
|
||||
if file_name or extension:
|
||||
if extension is not None:
|
||||
final_extension = extension
|
||||
elif file_name is not None:
|
||||
final_extension = get_file_ext(file_name)
|
||||
if extension is None:
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
if is_valid_file_ext(final_extension):
|
||||
return extension_to_function.get(final_extension, file_io_to_text)(file)
|
||||
if is_valid_file_ext(extension):
|
||||
func = extension_to_function.get(extension, file_io_to_text)
|
||||
file.seek(0)
|
||||
return func(file)
|
||||
|
||||
# Either the file somehow has no name or the extension is not one that we recognize
|
||||
# If unknown extension, maybe it's a text file
|
||||
file.seek(0)
|
||||
if is_text_file(file):
|
||||
return file_io_to_text(file)
|
||||
|
||||
raise ValueError("Unknown file extension and unknown text encoding")
|
||||
raise ValueError("Unknown file extension or not recognized as text data")
|
||||
|
||||
except Exception as e:
|
||||
if break_on_unprocessable:
|
||||
@@ -396,20 +426,93 @@ def extract_file_text(
|
||||
return ""
|
||||
|
||||
|
||||
def extract_text_and_images(
    file: IO[Any],
    file_name: str,
    pdf_pass: str | None = None,
) -> tuple[str, list[tuple[bytes, str]]]:
    """Extract a file's text and, for .docx and .pdf, its embedded images.

    Args:
        file: Binary stream of the file; it is rewound before parsing.
        file_name: Name used to pick a parser by extension.
        pdf_pass: Password to try if the file is an encrypted PDF.

    Returns:
        (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
        The image list is empty for every format other than .docx and .pdf,
        and whenever the Unstructured API produced the text. Unrecognized
        files and any extraction failure yield ("", []).
    """
    try:
        # Attempt Unstructured if an API key is configured (text only).
        if get_unstructured_api_key():
            try:
                file.seek(0)
                return (unstructured_to_text(file, file_name), [])
            except Exception as unstructured_error:
                # Same behavior as extract_file_text: do not lose the document
                # just because the remote service failed.
                logger.error(
                    f"Failed to process with Unstructured: {str(unstructured_error)}. "
                    "Falling back to normal processing."
                )

        extension = get_file_ext(file_name)
        file.seek(0)

        # Formats for which embedded images are extracted as well.
        if extension == ".docx":
            text_content, images = docx_to_text_and_images(file)
            return (text_content, images)

        if extension == ".pdf":
            text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
            return (text_content, images)

        # Formats for which only text is extracted.
        text_only_parsers = {
            ".pptx": pptx_to_text,
            ".xlsx": xlsx_to_text,
            ".eml": eml_to_text,
            ".epub": epub_to_text,
            ".html": parse_html_page_basic,
        }
        parser = text_only_parsers.get(extension)
        if parser is not None:
            return (parser(file), [])

        # Recognized plain-text extension
        if is_text_file_extension(file_name):
            encoding = detect_encoding(file)
            text_content_raw, _ = read_text_file(
                file, encoding=encoding, ignore_onyx_metadata=False
            )
            return (text_content_raw, [])

        # Image files and anything unrecognized: nothing to parse here.
        return ("", [])

    except Exception as e:
        logger.exception(f"Failed to extract text/images from {file_name}: {e}")
        return ("", [])
|
||||
|
||||
|
||||
def convert_docx_to_txt(
|
||||
file: UploadFile, file_store: FileStore, file_path: str
|
||||
) -> None:
|
||||
"""
|
||||
Helper to convert docx to a .txt file in the same filestore.
|
||||
"""
|
||||
file.file.seek(0)
|
||||
docx_content = file.file.read()
|
||||
doc = Document(BytesIO(docx_content))
|
||||
doc = DocxDocument(BytesIO(docx_content))
|
||||
|
||||
# Extract text from the document
|
||||
full_text = []
|
||||
for para in doc.paragraphs:
|
||||
full_text.append(para.text)
|
||||
|
||||
# Join the extracted text
|
||||
text_content = "\n".join(full_text)
|
||||
all_paras = [p.text for p in doc.paragraphs]
|
||||
text_content = "\n".join(all_paras)
|
||||
|
||||
txt_file_path = docx_to_txt_filename(file_path)
|
||||
file_store.save_file(
|
||||
@@ -422,7 +525,4 @@ def convert_docx_to_txt(
|
||||
|
||||
|
||||
def docx_to_txt_filename(file_path: str) -> str:
    """Return `file_path` with its extension replaced by ".txt".

    Only the extension of the final path component is replaced, so a dot in a
    directory name (e.g. "archive.v1/report") is left untouched. A path with
    no extension simply gets ".txt" appended.
    """
    root, _ext = os.path.splitext(file_path)
    return root + ".txt"
|
||||
|
||||
46
backend/onyx/file_processing/file_validation.py
Normal file
46
backend/onyx/file_processing/file_validation.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Centralized file type validation utilities.
|
||||
"""
|
||||
# MIME types of the standard image formats supported by most vision LLMs
IMAGE_MIME_TYPES = [
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
]

# Image MIME types that are excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
]


def is_valid_image_type(mime_type: str) -> bool:
    """
    Tell whether a MIME type denotes an image that should be processed.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if `mime_type` starts with "image/" and is not listed in
        EXCLUDED_IMAGE_TYPES; False otherwise, including for an empty value
    """
    if not mime_type:
        return False
    if mime_type in EXCLUDED_IMAGE_TYPES:
        return False
    return mime_type.startswith("image/")


def is_supported_by_vision_llm(mime_type: str) -> bool:
    """
    Tell whether a MIME type is one of the formats in IMAGE_MIME_TYPES.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if `mime_type` is listed in IMAGE_MIME_TYPES, False otherwise
    """
    is_listed = mime_type in IMAGE_MIME_TYPES
    return is_listed
|
||||
129
backend/onyx/file_processing/image_summarization.py
Normal file
129
backend/onyx/file_processing/image_summarization.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from PIL import Image
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_image_bytes(image_data: bytes) -> str:
    """Turn raw image bytes into the encoded string handed to the LLM.

    The bytes first go through _resize_image_if_needed (which shrinks images
    above its size limit, 20MB by default) and are then encoded as a base64
    string by _encode_image_for_llm_prompt."""
    return _encode_image_for_llm_prompt(_resize_image_if_needed(image_data))
|
||||
|
||||
|
||||
def summarize_image_pipeline(
    llm: LLM,
    image_data: bytes,
    query: str | None = None,
    system_prompt: str | None = None,
) -> str:
    """Generate a textual summary of an image with the given LLM.

    The image is prepared by prepare_image_bytes (resized if it is bigger than
    20MB, then base64-encoded) and sent to the LLM together with the optional
    `query` and `system_prompt`."""
    return _summarize_image(
        encoded_image=prepare_image_bytes(image_data),
        llm=llm,
        query=query,
        system_prompt=system_prompt,
    )
|
||||
|
||||
|
||||
def summarize_image_with_error_handling(
    llm: LLM | None,
    image_data: bytes,
    context_name: str,
    system_prompt: str = IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
    user_prompt_template: str = IMAGE_SUMMARIZATION_USER_PROMPT,
) -> str | None:
    """Summarize an image, returning None instead of raising on failure.

    Args:
        llm: The LLM with vision capabilities to use for summarization;
            None disables summarization
        image_data: The raw image bytes
        context_name: Name or title of the image for context
        system_prompt: System prompt to use for the LLM
        user_prompt_template: Template for the user prompt, should contain {title} placeholder

    Returns:
        The image summary text, or None if summarization failed or is disabled
    """
    if llm is None:
        return None

    user_prompt = user_prompt_template.format(title=context_name)
    try:
        return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
    except Exception:
        # Honor the documented contract: a failed summary must not abort the
        # caller, it just yields no summary.
        logger.exception(f"Failed to summarize image '{context_name}'")
        return None
|
||||
|
||||
|
||||
def _summarize_image(
    encoded_image: str,
    llm: LLM,
    query: str | None = None,
    system_prompt: str | None = None,
) -> str:
    """Ask the LLM (which must be multimodal) to summarize an encoded image.

    Args:
        encoded_image: The image as a URL string (here a base64 data URL)
        llm: The LLM to invoke
        query: Optional user prompt sent along with the image
        system_prompt: Optional system prompt placed before the user message

    Returns:
        The LLM response converted to a string

    Raises:
        ValueError: If invoking the LLM fails; the original error is chained
    """
    messages: list[BaseMessage] = []

    if system_prompt:
        messages.append(SystemMessage(content=system_prompt))

    # Only include a text part when there is text to send.
    user_content: list = []
    if query is not None:
        user_content.append({"type": "text", "text": query})
    user_content.append({"type": "image_url", "image_url": {"url": encoded_image}})
    messages.append(HumanMessage(content=user_content))

    try:
        return message_to_string(llm.invoke(messages))

    except Exception as e:
        # The messages embed the whole base64 image; cap the error text so
        # exceptions and logs stay readable.
        error_msg = f"Summarization failed. Messages: {messages}"[:1024]
        raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
def _encode_image_for_llm_prompt(image_data: bytes) -> str:
    """Encode image bytes as a base64 data URL.

    The MIME type in the URL is sniffed from the leading magic bytes (PNG,
    WebP, GIF). Anything else, including JPEG, is labelled image/jpeg.
    """
    if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
        mime_type = "image/png"
    elif image_data[:4] == b"RIFF" and image_data[8:12] == b"WEBP":
        mime_type = "image/webp"
    elif image_data.startswith((b"GIF87a", b"GIF89a")):
        mime_type = "image/gif"
    else:
        mime_type = "image/jpeg"

    base64_encoded_data = base64.b64encode(image_data).decode("utf-8")

    return f"data:{mime_type};base64,{base64_encoded_data}"
|
||||
|
||||
|
||||
def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes:
    """Shrink an image whose encoded size exceeds `max_size_mb` megabytes.

    Images at or below the limit are returned unchanged. Larger ones are
    scaled down to fit within 1024x1024 and re-encoded as JPEG (quality 85).
    """
    max_size_bytes = max_size_mb * 1024 * 1024
    if len(image_data) <= max_size_bytes:
        return image_data

    with Image.open(BytesIO(image_data)) as img:
        # Reduce dimensions for better size reduction
        img.thumbnail((1024, 1024), Image.Resampling.LANCZOS)

        # JPEG cannot store alpha or palette modes (e.g. RGBA, LA, P), so
        # convert those before saving instead of failing.
        to_save = img if img.mode in ("RGB", "L", "CMYK") else img.convert("RGB")

        output = BytesIO()
        # Save with lower quality for compression
        to_save.save(output, format="JPEG", quality=85)
        return output.getvalue()
|
||||
70
backend/onyx/file_processing/image_utils.py
Normal file
70
backend/onyx/file_processing/image_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def store_image_and_create_section(
    db_session: Session,
    image_data: bytes,
    file_name: str,
    display_name: str,
    media_type: str = "image/unknown",
    llm: LLM | None = None,
    file_origin: FileOrigin = FileOrigin.OTHER,
) -> Tuple[Section, str | None]:
    """
    Save an image to PGFileStore and build a Section that references it.

    Args:
        db_session: Database session
        image_data: Raw image bytes
        file_name: Base identifier for the file
        display_name: Human-readable name for the image
        media_type: MIME type of the image
        llm: Optional LLM with vision capabilities; when given, its summary of
            the image becomes the Section text
        file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)

    Returns:
        Tuple containing:
        - Section object with image reference and optional summary text
        - The file_name in PGFileStore or None if storage failed

    Raises:
        Re-raises the storage error unless CONTINUE_ON_CONNECTOR_FAILURE is
        set, in which case an empty Section and None are returned instead.
    """
    # Storage
    try:
        stored_file_name = save_bytes_to_pgfilestore(
            db_session=db_session,
            raw_bytes=image_data,
            media_type=media_type,
            identifier=file_name,
            display_name=display_name,
            file_origin=file_origin,
        ).file_name
    except Exception as e:
        logger.error(f"Failed to store image: {e}")
        if CONTINUE_ON_CONNECTOR_FAILURE:
            return Section(text=""), None
        raise

    # Summarization (only when an LLM was supplied)
    summary_text = ""
    if llm:
        summary = summarize_image_with_error_handling(llm, image_data, display_name)
        summary_text = summary or ""

    section = Section(text=summary_text, image_file_name=stored_file_name)
    return section, stored_file_name
|
||||
@@ -23,12 +23,9 @@ from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
|
||||
CHUNK_OVERLAP = 0
|
||||
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
|
||||
# overwhelm the actual contents of the chunk
|
||||
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
|
||||
# could be another 128 tokens leaving 256 for the actual contents
|
||||
MAX_METADATA_PERCENTAGE = 0.25
|
||||
CHUNK_MIN_CONTENT = 256
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -36,16 +33,8 @@ def _get_metadata_suffix_for_document_index(
|
||||
metadata: dict[str, str | list[str]], include_separator: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
|
||||
and a string of all of the values for the keyword search
|
||||
|
||||
For example, if we have the following metadata:
|
||||
{
|
||||
"author": "John Doe",
|
||||
"space": "Engineering"
|
||||
}
|
||||
The vector embedding string should include the relation between the key and value wheres as for keyword we only want John Doe
|
||||
and Engineering. The keys are repeat and much more noisy.
|
||||
Returns the metadata as a natural language string representation with all of the keys and values
|
||||
for the vector embedding and a string of all of the values for the keyword search.
|
||||
"""
|
||||
if not metadata:
|
||||
return "", ""
|
||||
@@ -74,12 +63,17 @@ def _get_metadata_suffix_for_document_index(
|
||||
|
||||
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
|
||||
"""
|
||||
Combines multiple DocAwareChunks into one large chunk (for “multipass” mode),
|
||||
appending the content and adjusting source_links accordingly.
|
||||
"""
|
||||
merged_chunk = DocAwareChunk(
|
||||
source_document=chunks[0].source_document,
|
||||
chunk_id=chunks[0].chunk_id,
|
||||
blurb=chunks[0].blurb,
|
||||
content=chunks[0].content,
|
||||
source_links=chunks[0].source_links or {},
|
||||
image_file_name=None,
|
||||
section_continuation=(chunks[0].chunk_id > 0),
|
||||
title_prefix=chunks[0].title_prefix,
|
||||
metadata_suffix_semantic=chunks[0].metadata_suffix_semantic,
|
||||
@@ -103,6 +97,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
|
||||
|
||||
|
||||
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Generates larger “grouped” chunks by combining sets of smaller chunks.
|
||||
"""
|
||||
large_chunks = []
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
|
||||
chunk_group = chunks[i : i + LARGE_CHUNK_RATIO]
|
||||
@@ -172,23 +169,60 @@ class Chunker:
|
||||
while start < total_tokens:
|
||||
end = min(start + content_token_limit, total_tokens)
|
||||
token_chunk = tokens[start:end]
|
||||
# Join the tokens to reconstruct the text
|
||||
chunk_text = " ".join(token_chunk)
|
||||
chunks.append(chunk_text)
|
||||
start = end
|
||||
return chunks
|
||||
|
||||
def _extract_blurb(self, text: str) -> str:
|
||||
"""
|
||||
Extract a short blurb from the text (first chunk of size `blurb_size`).
|
||||
"""
|
||||
texts = self.blurb_splitter.split_text(text)
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For “multipass” mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
return self.mini_chunk_splitter.split_text(chunk_text)
|
||||
return None
|
||||
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
|
||||
self,
|
||||
document: Document,
|
||||
chunks_list: list[DocAwareChunk],
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
is_continuation: bool = False,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix_semantic: str = "",
|
||||
metadata_suffix_keyword: str = "",
|
||||
image_file_name: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper to create a new DocAwareChunk, append it to chunks_list.
|
||||
"""
|
||||
new_chunk = DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks_list),
|
||||
blurb=self._extract_blurb(text),
|
||||
content=text,
|
||||
source_links=links or {0: ""},
|
||||
image_file_name=image_file_name,
|
||||
section_continuation=is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
def _chunk_document(
|
||||
self,
|
||||
document: Document,
|
||||
@@ -198,122 +232,156 @@ class Chunker:
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Loops through sections of the document, adds metadata and converts them into chunks.
|
||||
Loops through sections of the document, converting them into one or more chunks.
|
||||
If a section has an image_link, we treat it as a dedicated chunk.
|
||||
"""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
|
||||
def _create_chunk(
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
is_continuation: bool = False,
|
||||
) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=self._extract_blurb(text),
|
||||
content=text,
|
||||
source_links=links or {0: ""},
|
||||
section_continuation=is_continuation,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
section_link_text: str
|
||||
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
section_text = clean_text(section.text)
|
||||
section_link_text = section.link or ""
|
||||
# If there is no useful content, not even the title, just drop it
|
||||
# ADDED: if the Section has an image link
|
||||
image_url = section.image_file_name
|
||||
|
||||
# If there is no useful content, skip
|
||||
if not section_text and (not document.title or section_idx > 0):
|
||||
# If a section is empty and the document has no title, we can just drop it. We return a list of
|
||||
# DocAwareChunks where each one contains the necessary information needed down the line for indexing.
|
||||
# There is no concern about dropping whole documents from this list, it should not cause any indexing failures.
|
||||
logger.warning(
|
||||
f"Skipping section {section.text} from document "
|
||||
f"{document.semantic_identifier} due to empty text after cleaning "
|
||||
f"with link {section_link_text}"
|
||||
f"Skipping empty or irrelevant section in doc "
|
||||
f"{document.semantic_identifier}, link={section_link_text}"
|
||||
)
|
||||
continue
|
||||
|
||||
# CASE 1: If this is an image section, force a separate chunk
|
||||
if image_url:
|
||||
# First, if we have any partially built text chunk, finalize it
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
is_continuation=False,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# Create a chunk specifically for this image
|
||||
# (If the section has text describing the image, use that as content)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
section_text,
|
||||
links={0: section_link_text}
|
||||
if section_link_text
|
||||
else {}, # No text offsets needed for images
|
||||
image_file_name=image_url,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
# Continue to next section
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.tokenize(section_text))
|
||||
|
||||
# Large sections are considered self-contained/unique
|
||||
# Therefore, they start a new chunk and are not concatenated
|
||||
# at the end by other sections
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
if chunk_text:
|
||||
chunks.append(_create_chunk(chunk_text, link_offsets))
|
||||
link_offsets = {}
|
||||
if chunk_text.strip():
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
split_texts = self.chunk_splitter.split_text(section_text)
|
||||
|
||||
for i, split_text in enumerate(split_texts):
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and
|
||||
# Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true
|
||||
len(self.tokenizer.tokenize(split_text)) > content_token_limit
|
||||
and len(self.tokenizer.tokenize(split_text))
|
||||
> content_token_limit
|
||||
):
|
||||
# If STRICT_CHUNK_TOKEN_LIMIT is true, manually check
|
||||
# the token count of each split text to ensure it is
|
||||
# not larger than the content_token_limit
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
)
|
||||
for i, small_chunk in enumerate(smaller_chunks):
|
||||
chunks.append(
|
||||
_create_chunk(
|
||||
text=small_chunk,
|
||||
links={0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
)
|
||||
for j, small_chunk in enumerate(smaller_chunks):
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
small_chunk,
|
||||
{0: section_link_text},
|
||||
is_continuation=(j != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
else:
|
||||
chunks.append(
|
||||
_create_chunk(
|
||||
text=split_text,
|
||||
links={0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
split_text,
|
||||
{0: section_link_text},
|
||||
is_continuation=(i != 0),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.tokenize(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
# In the case where the whole section is shorter than a chunk, either add
|
||||
# to chunk or start a new one
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
if chunk_text:
|
||||
chunk_text += SECTION_SEPARATOR
|
||||
chunk_text += section_text
|
||||
link_offsets[current_offset] = section_link_text
|
||||
else:
|
||||
chunks.append(_create_chunk(chunk_text, link_offsets))
|
||||
# finalize the existing chunk
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets,
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
# start a new chunk
|
||||
link_offsets = {0: section_link_text}
|
||||
chunk_text = section_text
|
||||
|
||||
# Once we hit the end, if we're still in the process of building a chunk, add what we have.
|
||||
# If there is only whitespace left then don't include it. If there are no chunks at all
|
||||
# from the doc, we can just create a single chunk with the title.
|
||||
# finalize any leftover text chunk
|
||||
if chunk_text.strip() or not chunks:
|
||||
chunks.append(
|
||||
_create_chunk(
|
||||
chunk_text,
|
||||
link_offsets or {0: section_link_text},
|
||||
)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
chunk_text,
|
||||
link_offsets or {0: ""}, # safe default
|
||||
False,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
)
|
||||
|
||||
# If the chunk does not have any useable content, it will not be indexed
|
||||
return chunks
|
||||
|
||||
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
|
||||
@@ -321,10 +389,12 @@ class Chunker:
|
||||
if document.source == DocumentSource.GMAIL:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.tokenize(title_prefix))
|
||||
|
||||
# Metadata prep
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_suffix_keyword = ""
|
||||
metadata_tokens = 0
|
||||
@@ -337,19 +407,20 @@ class Chunker:
|
||||
)
|
||||
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
|
||||
|
||||
# If metadata is too large, skip it in the semantic content
|
||||
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
|
||||
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
|
||||
# context, there is no limit for the keyword component
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
# Adjust content token limit to accommodate title + metadata
|
||||
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
|
||||
# If there is not enough context remaining then just index the chunk with no prefix/suffix
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
# Not enough space left, so revert to full chunk without the prefix
|
||||
content_token_limit = self.chunk_token_limit
|
||||
title_prefix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
|
||||
# Chunk the document
|
||||
normal_chunks = self._chunk_document(
|
||||
document,
|
||||
title_prefix,
|
||||
@@ -358,6 +429,7 @@ class Chunker:
|
||||
content_token_limit,
|
||||
)
|
||||
|
||||
# Optional “multipass” large chunk creation
|
||||
if self.enable_multipass and self.enable_large_chunks:
|
||||
large_chunks = generate_large_chunks(normal_chunks)
|
||||
normal_chunks.extend(large_chunks)
|
||||
@@ -371,9 +443,8 @@ class Chunker:
|
||||
"""
|
||||
final_chunks: list[DocAwareChunk] = []
|
||||
for document in documents:
|
||||
if self.callback:
|
||||
if self.callback.should_stop():
|
||||
raise RuntimeError("Chunker.chunk: Stop signal detected")
|
||||
if self.callback and self.callback.should_stop():
|
||||
raise RuntimeError("Chunker.chunk: Stop signal detected")
|
||||
|
||||
chunks = self._handle_single_document(document)
|
||||
final_chunks.extend(chunks)
|
||||
|
||||
@@ -38,6 +38,7 @@ class IndexingEmbedder(ABC):
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
reduced_dimension: int | None,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
@@ -60,6 +61,7 @@ class IndexingEmbedder(ABC):
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
reduced_dimension=reduced_dimension,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
@@ -87,6 +89,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -99,6 +102,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url,
|
||||
api_version,
|
||||
deployment_name,
|
||||
reduced_dimension,
|
||||
callback,
|
||||
)
|
||||
|
||||
@@ -219,6 +223,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url=search_settings.api_url,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import Field
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
@@ -28,6 +29,7 @@ class BaseChunk(BaseModel):
|
||||
content: str
|
||||
# Holds the link and the offsets into the raw Chunk text
|
||||
source_links: dict[int, str] | None
|
||||
image_file_name: str | None
|
||||
# True if this Chunk's start is not at the start of a Section
|
||||
section_continuation: bool
|
||||
|
||||
@@ -143,10 +145,20 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
model_dim: int
|
||||
index_name: str | None
|
||||
multipass_indexing: bool
|
||||
embedding_precision: EmbeddingPrecision
|
||||
reduced_dimension: int | None = None
|
||||
|
||||
background_reindex_enabled: bool = True
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@property
|
||||
def final_embedding_dim(self) -> int:
|
||||
if self.reduced_dimension:
|
||||
return self.reduced_dimension
|
||||
return self.model_dim
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
|
||||
return cls(
|
||||
@@ -158,6 +170,9 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
provider_type=search_settings.provider_type,
|
||||
index_name=search_settings.index_name,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,12 +14,14 @@ from onyx.db.models import KVStore
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -43,9 +45,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User must authenticate"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
if not is_valid_schema_name(self.tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
|
||||
@@ -6,12 +6,14 @@ from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.chat_llm import DefaultMultiLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
@@ -86,6 +88,48 @@ def get_llms_for_persona(
|
||||
return _create_llm(model), _create_llm(fast_model)
|
||||
|
||||
|
||||
def get_default_llm_with_vision(
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM | None:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
|
||||
if not llm_providers:
|
||||
return None
|
||||
|
||||
for provider in llm_providers:
|
||||
model_name = provider.default_model_name
|
||||
fast_model_name = (
|
||||
provider.fast_default_model_name or provider.default_model_name
|
||||
)
|
||||
|
||||
if not model_name or not fast_model_name:
|
||||
continue
|
||||
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
raise ValueError("No LLM provider found that supports image input")
|
||||
|
||||
|
||||
def get_default_llms(
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
|
||||
@@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@@ -323,7 +322,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -89,6 +89,7 @@ class EmbeddingModel:
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@@ -100,6 +101,7 @@ class EmbeddingModel:
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.deployment_name = deployment_name
|
||||
self.reduced_dimension = reduced_dimension
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_name=model_name, provider_type=provider_type
|
||||
)
|
||||
@@ -188,6 +190,7 @@ class EmbeddingModel:
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
api_url=self.api_url,
|
||||
reduced_dimension=self.reduced_dimension,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
@@ -300,6 +303,7 @@ class EmbeddingModel:
|
||||
retrim_content=retrim_content,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_session_by_message_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChannelConfig
|
||||
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
@@ -31,12 +31,18 @@ from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.icons import source_to_github_img_link
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id
|
||||
from onyx.onyxbot.slack.utils import build_feedback_id
|
||||
from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
@@ -105,6 +111,77 @@ def _build_qa_feedback_block(
|
||||
)
|
||||
|
||||
|
||||
def _build_ephemeral_publication_block(
|
||||
channel_id: str,
|
||||
chat_message_id: int,
|
||||
message_info: SlackMessageInfo,
|
||||
original_question_ts: str,
|
||||
channel_conf: ChannelConfig,
|
||||
feedback_reminder_id: str | None = None,
|
||||
) -> Block:
|
||||
# check whether the message is in a thread
|
||||
if (
|
||||
message_info is not None
|
||||
and message_info.msg_to_respond is not None
|
||||
and message_info.thread_to_respond is not None
|
||||
and (message_info.msg_to_respond == message_info.thread_to_respond)
|
||||
):
|
||||
respond_ts = None
|
||||
else:
|
||||
respond_ts = original_question_ts
|
||||
|
||||
action_values_ephemeral_message_channel_config = (
|
||||
ActionValuesEphemeralMessageChannelConfig(
|
||||
channel_name=channel_conf.get("channel_name"),
|
||||
respond_tag_only=channel_conf.get("respond_tag_only"),
|
||||
respond_to_bots=channel_conf.get("respond_to_bots"),
|
||||
is_ephemeral=channel_conf.get("is_ephemeral", False),
|
||||
respond_member_group_list=channel_conf.get("respond_member_group_list"),
|
||||
answer_filters=channel_conf.get("answer_filters"),
|
||||
follow_up_tags=channel_conf.get("follow_up_tags"),
|
||||
show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False),
|
||||
)
|
||||
)
|
||||
|
||||
action_values_ephemeral_message_message_info = (
|
||||
ActionValuesEphemeralMessageMessageInfo(
|
||||
bypass_filters=message_info.bypass_filters,
|
||||
channel_to_respond=message_info.channel_to_respond,
|
||||
msg_to_respond=message_info.msg_to_respond,
|
||||
email=message_info.email,
|
||||
sender_id=message_info.sender_id,
|
||||
thread_messages=[],
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
is_bot_dm=message_info.is_bot_dm,
|
||||
thread_to_respond=respond_ts,
|
||||
)
|
||||
)
|
||||
|
||||
action_values_ephemeral_message = ActionValuesEphemeralMessage(
|
||||
original_question_ts=original_question_ts,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
chat_message_id=chat_message_id,
|
||||
message_info=action_values_ephemeral_message_message_info,
|
||||
channel_conf=action_values_ephemeral_message_channel_config,
|
||||
)
|
||||
|
||||
return ActionsBlock(
|
||||
block_id=build_publish_ephemeral_message_id(original_question_ts),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=SHOW_EVERYONE_ACTION_ID,
|
||||
text="📢 Share with Everyone",
|
||||
value=action_values_ephemeral_message.model_dump_json(),
|
||||
),
|
||||
ButtonElement(
|
||||
action_id=KEEP_TO_YOURSELF_ACTION_ID,
|
||||
text="🤫 Keep to Yourself",
|
||||
value=action_values_ephemeral_message.model_dump_json(),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_document_feedback_blocks() -> Block:
|
||||
return SectionBlock(
|
||||
text=(
|
||||
@@ -410,12 +487,11 @@ def _build_qa_response_blocks(
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
@@ -482,22 +558,26 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
tenant_id: str,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
skip_ai_feedback: bool = False,
|
||||
offer_ephemeral_publication: bool = False,
|
||||
expecting_search_result: bool = False,
|
||||
skip_restated_question: bool = False,
|
||||
) -> list[Block]:
|
||||
"""
|
||||
This function is a top level function that builds all the blocks for the Slack response.
|
||||
It also handles combining all the blocks together.
|
||||
"""
|
||||
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
if not skip_restated_question:
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
else:
|
||||
restate_question_block = []
|
||||
|
||||
if expecting_search_result:
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
@@ -517,18 +597,41 @@ def build_slack_response_blocks(
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
follow_up_block = []
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
if (
|
||||
channel_conf
|
||||
and channel_conf.get("follow_up_tags") is not None
|
||||
and not channel_conf.get("is_ephemeral", False)
|
||||
):
|
||||
follow_up_block.append(
|
||||
_build_follow_up_block(message_id=answer.chat_message_id)
|
||||
)
|
||||
|
||||
ai_feedback_block = []
|
||||
publish_ephemeral_message_block = []
|
||||
|
||||
if (
|
||||
offer_ephemeral_publication
|
||||
and answer.chat_message_id is not None
|
||||
and message_info.msg_to_respond is not None
|
||||
and channel_conf is not None
|
||||
):
|
||||
publish_ephemeral_message_block.append(
|
||||
_build_ephemeral_publication_block(
|
||||
channel_id=message_info.channel_to_respond,
|
||||
chat_message_id=answer.chat_message_id,
|
||||
original_question_ts=message_info.msg_to_respond,
|
||||
message_info=message_info,
|
||||
channel_conf=channel_conf,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
)
|
||||
|
||||
ai_feedback_block: list[Block] = []
|
||||
|
||||
if answer.chat_message_id is not None and not skip_ai_feedback:
|
||||
ai_feedback_block.append(
|
||||
_build_qa_feedback_block(
|
||||
@@ -550,6 +653,7 @@ def build_slack_response_blocks(
|
||||
all_blocks = (
|
||||
restate_question_block
|
||||
+ answer_blocks
|
||||
+ publish_ephemeral_message_block
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
|
||||
@@ -2,6 +2,8 @@ from enum import Enum
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself"
|
||||
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -5,21 +6,32 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.views import View
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.webhook import WebhookClient
|
||||
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import create_doc_retrieval_feedback
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.blocks import get_document_feedback_blocks
|
||||
from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel
|
||||
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FeedbackVisibility
|
||||
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from onyx.onyxbot.slack.handlers.handle_message import (
|
||||
remove_scheduled_feedback_reminder,
|
||||
@@ -35,15 +47,48 @@ from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import get_feedback_visibility
|
||||
from onyx.onyxbot.slack.utils import read_slack_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_db_doc_id_to_document_ids(
|
||||
citation_dict: dict[int, int], top_documents: list[SavedSearchDoc]
|
||||
) -> list[CitationInfo]:
|
||||
citation_list_with_document_id = []
|
||||
for citation_num, db_doc_id in citation_dict.items():
|
||||
if db_doc_id is not None:
|
||||
matching_doc = next(
|
||||
(d for d in top_documents if d.db_doc_id == db_doc_id), None
|
||||
)
|
||||
if matching_doc:
|
||||
citation_list_with_document_id.append(
|
||||
CitationInfo(
|
||||
citation_num=citation_num, document_id=matching_doc.document_id
|
||||
)
|
||||
)
|
||||
return citation_list_with_document_id
|
||||
|
||||
|
||||
def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]:
    """Build the ``CitationInfo`` list for a chat message.

    Returns an empty list when the message carries no citations. Otherwise the
    message's citations are resolved against the top documents of its context
    docs (or against an empty list when there are no context docs).
    """
    citations = chat_message_detail.citations
    if citations is None:
        return []

    context_docs = chat_message_detail.context_docs
    documents = context_docs.top_documents if context_docs else []
    return _convert_db_doc_id_to_document_ids(citations, documents)
|
||||
|
||||
|
||||
def handle_doc_feedback_button(
|
||||
req: SocketModeRequest,
|
||||
client: TenantSocketModeClient,
|
||||
@@ -58,7 +103,7 @@ def handle_doc_feedback_button(
|
||||
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
|
||||
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
data = View(
|
||||
type="modal",
|
||||
@@ -84,7 +129,7 @@ def handle_generate_answer_button(
|
||||
channel_id = req.payload["channel"]["id"]
|
||||
channel_name = req.payload["channel"]["name"]
|
||||
message_ts = req.payload["message"]["ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
user_id = req.payload["user"]["id"]
|
||||
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
|
||||
email = expert_info.email if expert_info else None
|
||||
@@ -106,7 +151,7 @@ def handle_generate_answer_button(
|
||||
|
||||
# tell the user that we're working on it
|
||||
# Send an ephemeral message to the user that we're generating the answer
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
receiver_ids=[user_id],
|
||||
@@ -114,7 +159,7 @@ def handle_generate_answer_button(
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -136,13 +181,184 @@ def handle_generate_answer_button(
|
||||
slack_channel_config=slack_channel_config,
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
tenant_id=client.tenant_id,
|
||||
channel=channel_id,
|
||||
logger=logger,
|
||||
feedback_reminder_id=None,
|
||||
)
|
||||
|
||||
|
||||
def handle_publish_ephemeral_message_button(
    req: SocketModeRequest,
    client: TenantSocketModeClient,
    action_id: str,
) -> None:
    """Handle the Share with Everyone / Keep for Yourself buttons of an
    ephemeral answer.

    The button's ``value`` is a JSON document holding the original message
    info, the channel configuration and the id of the stored chat message.
    The answer is rebuilt from that chat message and then either:

    * ``SHOW_EVERYONE_ACTION_ID``: the ephemeral message is deleted through
      the payload's ``response_url`` and the answer is posted as a regular
      message in the thread of the original question, or
    * ``KEEP_TO_YOURSELF_ACTION_ID``: the answer stays ephemeral, but is
      re-rendered without the publish buttons and with the feedback block.

    Args:
        req: Socket-mode request carrying the Slack interaction payload.
        client: Tenant-aware socket-mode client; its ``web_client`` is used
            to post messages.
        action_id: Id of the button that was pressed.

    Raises:
        ValueError: If ``original_question_ts``, the ephemeral message ts or
            ``chat_message_id`` is missing from the payload, or if no Onyx
            user can be determined from the email in the payload.
    """
    channel_id = req.payload["channel"]["id"]
    ephemeral_message_ts = req.payload["container"]["message_ts"]

    slack_sender_id = req.payload["user"]["id"]
    response_url = req.payload["response_url"]
    webhook = WebhookClient(url=response_url)

    # The additional data required that was added to buttons.
    # Specifically, this contains the message_info, channel_conf information
    # and some additional attributes.
    value_dict = json.loads(req.payload["actions"][0]["value"])

    original_question_ts = value_dict.get("original_question_ts")
    if not original_question_ts:
        raise ValueError("Missing original_question_ts in the payload")
    if not ephemeral_message_ts:
        raise ValueError("Missing ephemeral_message_ts in the payload")

    feedback_reminder_id = value_dict.get("feedback_reminder_id")

    slack_message_info = SlackMessageInfo(**value_dict["message_info"])
    channel_conf = value_dict.get("channel_conf")

    user_email = value_dict.get("message_info", {}).get("email")

    chat_message_id = value_dict.get("chat_message_id")

    # Obtain onyx_user and chat_message information
    if not chat_message_id:
        raise ValueError("Missing chat_message_id in the payload")

    # Without an email no user can be resolved; fail before opening a session.
    if not user_email:
        raise ValueError("Cannot determine onyx_user_id from email in payload")

    with get_session_with_current_tenant() as db_session:
        onyx_user = get_user_by_email(user_email, db_session)
        if not onyx_user:
            raise ValueError("Cannot determine onyx_user_id from email in payload")
        try:
            chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session)
        except ValueError:
            # NOTE(review): this fallback retries the lookup with no user id;
            # confirm this is acceptable from an access-control standpoint.
            chat_message = get_chat_message(
                chat_message_id, None, db_session
            )  # is this good idea?
        except Exception as e:
            logger.error(f"Failed to get chat message: {e}")
            raise

        chat_message_detail = translate_db_message_to_chat_message_detail(chat_message)

        # construct the proper citation format and then the answer in the suitable format
        # we need to construct the blocks.
        citation_list = _build_citation_list(chat_message_detail)

        onyx_bot_answer = ChatOnyxBotResponse(
            answer=chat_message_detail.message,
            citations=citation_list,
            chat_message_id=chat_message_id,
            docs=QADocsResponse(
                top_documents=chat_message_detail.context_docs.top_documents
                if chat_message_detail.context_docs
                else [],
                predicted_flow=None,
                predicted_search=None,
                applied_source_filters=None,
                applied_time_cutoff=None,
                recency_bias_multiplier=1.0,
            ),
            llm_selected_doc_indices=None,
            error_msg=None,
        )

    def _build_final_blocks() -> list[Any]:
        # Both buttons render the same blocks: no ephemeral-publication
        # buttons, AI feedback included, restated question skipped.
        return build_slack_response_blocks(
            answer=onyx_bot_answer,
            message_info=slack_message_info,
            channel_conf=channel_conf,
            use_citations=True,
            feedback_reminder_id=feedback_reminder_id,
            skip_ai_feedback=False,
            offer_ephemeral_publication=False,
            skip_restated_question=True,
        )

    # Note: we need to use the webhook and the respond_url to update/delete ephemeral messages
    if action_id == SHOW_EVERYONE_ACTION_ID:
        # Convert to non-ephemeral message in thread
        try:
            webhook.send(
                response_type="ephemeral",
                text="",
                blocks=[],
                replace_original=True,
                delete_original=True,
            )
        except Exception as e:
            # Best effort: failing to delete the ephemeral copy must not
            # prevent publishing the answer.
            logger.error(f"Failed to send webhook: {e}")

        # remove handling of ephemeral block and add AI feedback.
        all_blocks = _build_final_blocks()
        try:
            # Post in thread as non-ephemeral message
            respond_in_thread_or_channel(
                client=client.web_client,
                channel=channel_id,
                receiver_ids=None,  # If respond_member_group_list is set, send to them. TODO: check!
                text="Hello! Onyx has some results for you!",
                blocks=all_blocks,
                thread_ts=original_question_ts,
                # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
                unfurl=False,
                send_as_ephemeral=False,
            )
        except Exception as e:
            logger.error(f"Failed to publish ephemeral message: {e}")
            raise

    elif action_id == KEEP_TO_YOURSELF_ACTION_ID:
        # Keep as ephemeral message in channel or thread, but remove the publish button and add feedback button
        changed_blocks = _build_final_blocks()

        try:
            if slack_message_info.thread_to_respond is not None:
                # There seems to be a bug in slack where an update within the thread
                # actually leads to the update to be posted in the channel. Therefore,
                # for now we delete the original ephemeral message and post a new one
                # if the ephemeral message is in a thread.
                webhook.send(
                    response_type="ephemeral",
                    text="",
                    blocks=[],
                    replace_original=True,
                    delete_original=True,
                )

                respond_in_thread_or_channel(
                    client=client.web_client,
                    channel=channel_id,
                    receiver_ids=[slack_sender_id],
                    text="Your personal response, sent as an ephemeral message.",
                    blocks=changed_blocks,
                    thread_ts=original_question_ts,
                    # don't unfurl, since otherwise we will have 5+ previews which makes the message very long
                    unfurl=False,
                    send_as_ephemeral=True,
                )
            else:
                # This works fine if the ephemeral message is in the channel
                webhook.send(
                    response_type="ephemeral",
                    text="Your personal response, sent as an ephemeral message.",
                    blocks=changed_blocks,
                    replace_original=True,
                    delete_original=False,
                )
        except Exception as e:
            logger.error(f"Failed to send webhook: {e}")

    else:
        # Neither known button: nothing to do, but leave a trace for debugging.
        logger.warning(
            f"Unexpected action_id for ephemeral message button: {action_id}"
        )
|
||||
|
||||
|
||||
def handle_slack_feedback(
|
||||
feedback_id: str,
|
||||
feedback_type: str,
|
||||
@@ -151,17 +367,23 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Get Onyx user from Slack ID
|
||||
expert_info = expert_info_from_slack_id(
|
||||
user_id_to_post_confirmation, client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
onyx_user = get_user_by_email(email, db_session) if email else None
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
feedback_text="",
|
||||
chat_message_id=message_id,
|
||||
user_id=None, # no "user" for Slack bot for now
|
||||
user_id=onyx_user.id if onyx_user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
remove_scheduled_feedback_reminder(
|
||||
@@ -215,7 +437,7 @@ def handle_slack_feedback(
|
||||
else:
|
||||
msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer"
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel_id_to_post_confirmation,
|
||||
text=msg,
|
||||
@@ -234,7 +456,7 @@ def handle_followup_button(
|
||||
action_id = cast(str, action.get("block_id"))
|
||||
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_FOLLOWUP_EMOJI,
|
||||
@@ -246,7 +468,7 @@ def handle_followup_button(
|
||||
|
||||
tag_ids: list[str] = []
|
||||
group_ids: list[str] = []
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel_id
|
||||
)
|
||||
@@ -267,7 +489,7 @@ def handle_followup_button(
|
||||
|
||||
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
text="Received your request for more help",
|
||||
@@ -317,7 +539,7 @@ def handle_followup_resolved_button(
|
||||
) -> None:
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
message_ts = req.payload["container"]["message_ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
clicker_name = get_clicker_name(req, client)
|
||||
|
||||
@@ -351,7 +573,7 @@ def handle_followup_resolved_button(
|
||||
|
||||
resolved_block = SectionBlock(text=msg_text)
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
text="Your request for help as been addressed!",
|
||||
|
||||
@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -18,7 +18,7 @@ from onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import slack_usage_report
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -29,7 +29,7 @@ logger_base = setup_logger()
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender_id:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=details.channel_to_respond,
|
||||
thread_ts=details.msg_to_respond,
|
||||
@@ -109,7 +109,6 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -135,9 +134,7 @@ def handle_message(
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
slack_usage_report(
|
||||
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
|
||||
)
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
@@ -205,7 +202,7 @@ def handle_message(
|
||||
# which would just respond to the sender
|
||||
if send_to and is_bot_msg:
|
||||
if sender_id:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=[sender_id],
|
||||
@@ -218,11 +215,12 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if message_info.email:
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
# standard answers should be published in a thread
|
||||
used_standard_answer = handle_standard_answers(
|
||||
message_info=message_info,
|
||||
receiver_ids=send_to,
|
||||
@@ -244,6 +242,5 @@ def handle_message(
|
||||
channel=channel,
|
||||
logger=logger,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
return issue_with_regular_answer
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -34,7 +33,7 @@ from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.handlers.utils import slackify_message_thread
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
@@ -72,7 +71,6 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
@@ -84,19 +82,45 @@ def handle_regular_answer(
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
|
||||
# Capture whether response mode for channel is ephemeral. Even if the channel is set
|
||||
# to respond with an ephemeral message, we still send as non-ephemeral if
|
||||
# the message is a dm with the Onyx bot.
|
||||
send_as_ephemeral = (
|
||||
slack_channel_config.channel_config.get("is_ephemeral", False)
|
||||
and not message_info.is_bot_dm
|
||||
)
|
||||
|
||||
# If the channel is configured to respond with an ephemeral message,
|
||||
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
|
||||
# This will make documents privately accessible to the user available to Onyx Bot answers.
|
||||
# Otherwise - if not ephemeral or DM to Onyx Bot - we must use None as the user to restrict
|
||||
# to public docs.
|
||||
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.is_bot_dm or send_as_ephemeral:
|
||||
if message_info.email:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
target_thread_ts = (
|
||||
None
|
||||
if send_as_ephemeral and len(message_info.thread_messages) < 2
|
||||
else message_ts_to_respond_to
|
||||
)
|
||||
target_receiver_ids = (
|
||||
[message_info.sender_id]
|
||||
if message_info.sender_id and send_as_ephemeral
|
||||
else receiver_ids
|
||||
)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
@@ -136,11 +160,10 @@ def handle_regular_answer(
|
||||
history_messages = messages[:-1]
|
||||
single_message_history = slackify_message_thread(history_messages) or None
|
||||
|
||||
# Always check for ACL permissions, also for document sets that were explicitly added
|
||||
# to the Bot by the Administrator. (Change relative to earlier behavior where all documents
|
||||
# in an attached document set were available to all users in the channel.)
|
||||
bypass_acl = False
|
||||
if slack_channel_config.persona and slack_channel_config.persona.document_sets:
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/onyx" command, then it should have a message ts to respond to
|
||||
@@ -157,7 +180,7 @@ def handle_regular_answer(
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, onyx_user: User | None
|
||||
) -> ChatOnyxBotResponse:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=onyx_user,
|
||||
@@ -197,7 +220,7 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=user_message.message,
|
||||
user=user,
|
||||
@@ -221,12 +244,13 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
@@ -244,32 +268,36 @@ def handle_regular_answer(
|
||||
if answer is None:
|
||||
assert DISABLE_GENERATIVE_AI is True
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Hello! Onyx has some results for you!",
|
||||
blocks=[
|
||||
SectionBlock(
|
||||
text="Onyx is down for maintenance.\nWe're working hard on recharging the AI!"
|
||||
)
|
||||
],
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
respond_in_thread(
|
||||
|
||||
# If the channel is ephemeral, we don't need to send a message to the user since they will already see the message
|
||||
if target_receiver_ids and not send_as_ephemeral:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
return False
|
||||
@@ -318,12 +346,13 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -351,48 +380,67 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Found no citations or quotes when trying to answer.",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
return True
|
||||
|
||||
if (
|
||||
send_as_ephemeral
|
||||
and target_receiver_ids is not None
|
||||
and len(target_receiver_ids) == 1
|
||||
):
|
||||
offer_ephemeral_publication = True
|
||||
skip_ai_feedback = True
|
||||
else:
|
||||
offer_ephemeral_publication = False
|
||||
skip_ai_feedback = False if feedback_reminder_id else True
|
||||
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
expecting_search_result=expecting_search_result,
|
||||
offer_ephemeral_publication=offer_ephemeral_publication,
|
||||
skip_ai_feedback=skip_ai_feedback,
|
||||
)
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=[message_info.sender_id]
|
||||
if message_info.is_bot_msg and message_info.sender_id
|
||||
else receiver_ids,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Hello! Onyx has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message
|
||||
# so we shouldn't send_team_member_message
|
||||
if receiver_ids and message_ts_to_respond_to is not None:
|
||||
if (
|
||||
target_receiver_ids
|
||||
and message_ts_to_respond_to is not None
|
||||
and not send_as_ephemeral
|
||||
and target_thread_ts is not None
|
||||
):
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
receiver_ids=target_receiver_ids,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user