Compare commits

..

1 Commits

Author SHA1 Message Date
Yuhong Sun
4293543a6a k 2024-07-20 16:48:05 -07:00
676 changed files with 24100 additions and 36642 deletions

View File

@@ -1,25 +0,0 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge

View File

@@ -0,0 +1,33 @@
name: Build Backend Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-backend
jobs:
build:
# TODO: make this a matrix build like the web containers
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Backend Image Docker Build
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: false
tags: |
${{ env.REGISTRY_IMAGE }}:latest
build-args: |
DANSWER_VERSION=v0.0.1

View File

@@ -7,8 +7,7 @@ on:
jobs:
build-and-push:
runs-on:
group: amd64-image-builders
runs-on: ubuntu-latest
steps:
- name: Checkout code

View File

@@ -0,0 +1,53 @@
name: Build Web Image on Merge Group
on:
merge_group:
types: [checks_requested]
env:
REGISTRY_IMAGE: danswer/danswer-web-server
jobs:
build:
runs-on:
group: ${{ matrix.platform == 'linux/amd64' && 'amd64-image-builders' || 'arm64-image-builders' }}
strategy:
fail-fast: false
matrix:
platform:
- linux/amd64
- linux/arm64
steps:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:latest
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build by digest
id: build
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: ${{ matrix.platform }}
push: false
build-args: |
DANSWER_VERSION=v0.0.1
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,7 +1,6 @@
name: Python Checks
on:
merge_group:
pull_request:
branches: [ main ]

View File

@@ -1,7 +1,6 @@
name: Python Unit Tests
on:
merge_group:
pull_request:
branches: [ main ]

View File

@@ -4,19 +4,18 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request: null
jobs:
quality-checks:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.0
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- uses: pre-commit/action@v3.0.0
with:
extra_args: --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }}

View File

@@ -1,172 +0,0 @@
name: Run Integration Tests
concurrency:
group: Run-Integration-Tests-${{ github.head_ref }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches: [ main ]
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
integration-tests:
runs-on:
group: 'arm64-image-builders'
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build Web Docker image
uses: docker/build-push-action@v5
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
tags: danswer/danswer-web-server:it
cache-from: type=registry,ref=danswer/danswer-web-server:it
cache-to: |
type=registry,ref=danswer/danswer-web-server:it,mode=max
type=inline
- name: Build Backend Docker image
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
type=registry,ref=danswer/danswer-backend:it,mode=max
type=inline
- name: Build Model Server Docker image
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
pull: true
push: true
load: true
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
type=registry,ref=danswer/danswer-model-server:it,mode=max
type=inline
- name: Build integration test Docker image
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
pull: true
push: true
load: true
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |
type=registry,ref=danswer/integration-test-runner:it,mode=max
type=inline
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
IMAGE_TAG=it \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run integration tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
danswer/integration-test-runner:it
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v3
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_DOC_RELEVANCE=True
DISABLE_LLM_CHUNK_FILTER=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)

View File

@@ -39,8 +39,7 @@
"--reload",
"--port",
"9000"
],
"consoleTitle": "Model Server"
]
},
{
"name": "API Server",
@@ -59,8 +58,7 @@
"--reload",
"--port",
"8080"
],
"consoleTitle": "API Server"
]
},
{
"name": "Indexing",
@@ -70,12 +68,11 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"ENABLE_MINI_CHUNK": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"consoleTitle": "Indexing"
}
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
@@ -93,8 +90,7 @@
},
"args": [
"--no-indexing"
],
"consoleTitle": "Background Jobs"
]
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
@@ -129,17 +125,5 @@
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
],
"compounds": [
{
"name": "Run Danswer",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
]
}
]
}

View File

@@ -68,9 +68,7 @@ RUN apt-get update && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \

View File

@@ -18,22 +18,14 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
# Pre-downloading models for setups with limited egress
# Download tokenizers, distilbert for the Danswer model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
# running Danswer, don't overwrite it with the built in cache folder
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
AutoTokenizer.from_pretrained('danswer/intent-model'); \
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download('danswer/intent-model'); \
snapshot_download('intfloat/e5-base-v2'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
WORKDIR /app

View File

@@ -1,81 +1,72 @@
import asyncio
from logging.config import fileConfig
from typing import Tuple
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool, text
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
# Alembic Config object
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Add your model's MetaData object here
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
def get_schema_options() -> str:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(','):
if '=' in pair:
key, value = pair.split('=', 1)
x_args[key] = value
schema_name = x_args.get('schema', 'public')
return schema_name
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = build_connection_string()
schema = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
schema = get_schema_options()
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
print("Running async migrations")
"""Run migrations in 'online' mode."""
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -86,10 +77,13 @@ async def run_async_migrations() -> None:
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:

View File

@@ -17,11 +17,15 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_session", "current_alternate_model")
# ### end Alembic commands ###

View File

@@ -1,26 +0,0 @@
"""add_indexing_start_to_connector
Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
)
def downgrade() -> None:
op.drop_column("connector", "indexing_start")

View File

@@ -9,9 +9,9 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
from alembic_utils import encrypt_string
import json
from danswer.utils.encryption import encrypt_string_to_bytes
# revision identifiers, used by Alembic.
revision = "0a98909f2757"
@@ -57,7 +57,7 @@ def upgrade() -> None:
# In other words, this upgrade does not apply the encryption. Porting existing sensitive data
# and key rotation currently is not supported and will come out in the future
for row_id, creds, _ in results:
creds_binary = encrypt_string(json.dumps(creds))
creds_binary = encrypt_string_to_bytes(json.dumps(creds))
connection.execute(
creds_table.update()
.where(creds_table.c.id == row_id)
@@ -86,7 +86,7 @@ def upgrade() -> None:
results = connection.execute(sa.select(llm_table))
for row_id, api_key, _ in results:
llm_key = encrypt_string(api_key)
llm_key = encrypt_string_to_bytes(api_key)
connection.execute(
llm_table.update()
.where(llm_table.c.id == row_id)

View File

@@ -8,7 +8,7 @@ Create Date: 2023-11-11 20:51:24.228999
from alembic import op
import sqlalchemy as sa
from alembic_utils import DocumentSource
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "15326fcec57e"

View File

@@ -1,134 +0,0 @@
"""embedding model -> search settings
Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic_utils import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
# Rename the table
op.rename_table("embedding_model", "search_settings")
# Add new columns
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
),
)
op.add_column(
"search_settings",
sa.Column(
"multilingual_expansion",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
)
op.add_column(
"search_settings",
sa.Column(
"disable_rerank_for_streaming",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
)
op.add_column(
"search_settings",
sa.Column(
"num_rerank",
sa.Integer(),
nullable=False,
server_default=str(NUM_POSTPROCESSED_RESULTS),
),
)
# Add the new column as nullable initially
op.add_column(
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
)
# Populate the new column with data from the existing embedding_model_id
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
# Create the foreign key constraint
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)
# Make the new column non-nullable
op.alter_column("index_attempt", "search_settings_id", nullable=False)
# Drop the old embedding_model_id column
op.drop_column("index_attempt", "embedding_model_id")
def downgrade() -> None:
# Add back the embedding_model_id column
op.add_column(
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
)
# Populate the old column with data from search_settings_id
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
# Make the old column non-nullable
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
# Drop the foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Drop the new search_settings_id column
op.drop_column("index_attempt", "search_settings_id")
# Rename the table back
op.rename_table("search_settings", "embedding_model")
# Remove added columns
op.drop_column("embedding_model", "num_rerank")
op.drop_column("embedding_model", "rerank_api_key")
op.drop_column("embedding_model", "rerank_provider_type")
op.drop_column("embedding_model", "rerank_model_name")
op.drop_column("embedding_model", "disable_rerank_for_streaming")
op.drop_column("embedding_model", "multilingual_expansion")
op.drop_column("embedding_model", "multipass_indexing")
op.create_foreign_key(
"index_attempt__embedding_model_fk",
"index_attempt",
"embedding_model",
["embedding_model_id"],
["id"],
)

View File

@@ -1,44 +0,0 @@
"""notifications
Revision ID: 213fd978c6d8
Revises: 5fc1f54cc252
Create Date: 2024-08-10 11:13:36.070790
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "213fd978c6d8"
down_revision = "5fc1f54cc252"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"notification",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"notif_type",
sa.String(),
nullable=False,
),
sa.Column(
"user_id",
sa.UUID(),
nullable=True,
),
sa.Column("dismissed", sa.Boolean(), nullable=False),
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("notification")

View File

@@ -79,7 +79,7 @@ def downgrade() -> None:
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"document_retrieval",
"chat_message",
["chat_message_id"],
["id"],

View File

@@ -160,28 +160,12 @@ def downgrade() -> None:
nullable=False,
),
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
constraint["name"] == "fk_index_attempt_credential_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
if any(
constraint["name"] == "fk_index_attempt_connector_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")

View File

@@ -1,32 +0,0 @@
"""Add Above Below to Persona
Revision ID: 2d2304e27d8c
Revises: 4b08d97e175a
Create Date: 2024-08-21 19:15:15.762948
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2d2304e27d8c"
down_revision = "4b08d97e175a"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True))
op.execute(
"UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL"
)
op.alter_column("persona", "chunks_above", nullable=False)
op.alter_column("persona", "chunks_below", nullable=False)
def downgrade() -> None:
op.drop_column("persona", "chunks_below")
op.drop_column("persona", "chunks_above")

View File

@@ -1,70 +0,0 @@
"""Add icon_color and icon_shape to Persona
Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562
"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None
colorOptions = [
"#FF6FBF",
"#6FB1FF",
"#B76FFF",
"#FFB56F",
"#6FFF8D",
"#FF6F6F",
"#6FFFFF",
]
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
center_fill = random.choice(center_squares)
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
random.shuffle(remaining_squares)
for i in range(10 - bin(center_fill).count("1")):
center_fill |= 1 << remaining_squares[i]
return center_fill
def upgrade() -> None:
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
persona = table(
"persona",
column("id", sa.Integer),
column("icon_color", sa.String),
column("icon_shape", sa.Integer),
)
conn = op.get_bind()
personas = conn.execute(select(persona.c.id))
for persona_id in personas:
random_color = random.choice(colorOptions)
random_shape = generate_random_shape()
conn.execute(
persona.update()
.where(persona.c.id == persona_id[0])
.values(icon_color=random_color, icon_shape=random_shape)
)
def downgrade() -> None:
op.drop_column("persona", "icon_shape")
op.drop_column("persona", "uploaded_image_id")
op.drop_column("persona", "icon_color")

View File

@@ -1,90 +0,0 @@
"""Add curator fields
Revision ID: 351faebd379d
Revises: ee3f4b47fad5
Create Date: 2024-08-15 22:37:08.397052
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add is_curator column to User__UserGroup table
op.add_column(
"user__user_group",
sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"),
)
# Use batch mode to modify the enum type
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
existing_nullable=False,
)
# Create the association table
op.create_table(
"credential__user_group",
sa.Column("credential_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["credential_id"],
["credential.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("credential_id", "user_group_id"),
)
op.add_column(
"credential",
sa.Column(
"curator_public", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
# Update existing records to ensure they fit within the BASIC/ADMIN roles
op.execute(
"UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')"
)
# Remove is_curator column from User__UserGroup table
op.drop_column("user__user_group", "is_curator")
with op.batch_alter_table("user", schema=None) as batch_op:
batch_op.alter_column( # type: ignore[attr-defined]
"role",
type_=sa.Enum(
"BASIC", "ADMIN", name="userrole", native_enum=False, length=20
),
existing_type=sa.Enum(
"BASIC",
"ADMIN",
"CURATOR",
"GLOBAL_CURATOR",
name="userrole",
native_enum=False,
),
existing_nullable=False,
)
# Drop the association table
op.drop_table("credential__user_group")
op.drop_column("credential", "curator_public")

View File

@@ -18,6 +18,7 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
@@ -28,8 +29,10 @@ def upgrade() -> None:
["alternate_assistant_id"],
["id"],
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -1,42 +0,0 @@
"""Rename index_origin to index_recursively
Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_recursively}',
'true'::jsonb
) - 'index_origin'
WHERE connector_specific_config ? 'index_origin'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_origin}',
connector_specific_config->'index_recursively'
) - 'index_recursively'
WHERE connector_specific_config ? 'index_recursively'
"""
)

View File

@@ -1,49 +0,0 @@
"""Add display_model_names to llm_provider
Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None
default_models_by_provider = {
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
"bedrock": [
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"mistral.mistral-large-2402-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
],
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
)
connection = op.get_bind()
for provider, models in default_models_by_provider.items():
connection.execute(
sa.text(
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
),
{"models": models, "provider": provider},
)
def downgrade() -> None:
op.drop_column("llm_provider", "display_model_names")

View File

@@ -1,80 +0,0 @@
"""Moved status to connector credential pair
Revision ID: 4a951134c801
Revises: 7477a5f5d728
Create Date: 2024-08-10 19:20:34.527559
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4a951134c801"
down_revision = "7477a5f5d728"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"status",
sa.Enum(
"ACTIVE",
"PAUSED",
"DELETING",
name="connectorcredentialpairstatus",
native_enum=False,
),
nullable=True,
),
)
# Update status of connector_credential_pair based on connector's disabled status
op.execute(
"""
UPDATE connector_credential_pair
SET status = CASE
WHEN (
SELECT disabled
FROM connector
WHERE connector.id = connector_credential_pair.connector_id
) = FALSE THEN 'ACTIVE'
ELSE 'PAUSED'
END
"""
)
# Make the status column not nullable after setting values
op.alter_column("connector_credential_pair", "status", nullable=False)
op.drop_column("connector", "disabled")
def downgrade() -> None:
op.add_column(
"connector",
sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
)
# Update disabled status of connector based on connector_credential_pair's status
op.execute(
"""
UPDATE connector
SET disabled = CASE
WHEN EXISTS (
SELECT 1
FROM connector_credential_pair
WHERE connector_credential_pair.connector_id = connector.id
AND connector_credential_pair.status = 'ACTIVE'
) THEN FALSE
ELSE TRUE
END
"""
)
# Make the disabled column not nullable after setting values
op.alter_column("connector", "disabled", nullable=False)
op.drop_column("connector_credential_pair", "status")

View File

@@ -1,34 +0,0 @@
"""change default prune_freq
Revision ID: 4b08d97e175a
Revises: d9ec13955951
Create Date: 2024-08-20 15:28:52.993827
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "4b08d97e175a"
down_revision = "d9ec13955951"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 2592000
WHERE prune_freq = 86400
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET prune_freq = 86400
WHERE prune_freq = 2592000
"""
)

View File

@@ -1,72 +0,0 @@
"""Add type to credentials
Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new 'source' column to the 'credential' table
op.add_column(
"credential",
sa.Column(
"source",
sa.String(length=100), # Use String instead of Enum
nullable=True, # Initially allow NULL values
),
)
op.add_column(
"credential",
sa.Column(
"name",
sa.String(),
nullable=True,
),
)
# Create a temporary table that maps each credential to a single connector source.
# This is needed because a credential can be associated with multiple connectors,
# but we want to assign a single source to each credential.
# We use DISTINCT ON to ensure we only get one row per credential_id.
op.execute(
"""
CREATE TEMPORARY TABLE temp_connector_credential AS
SELECT DISTINCT ON (cc.credential_id)
cc.credential_id,
c.source AS connector_source
FROM connector_credential_pair cc
JOIN connector c ON cc.connector_id = c.id
"""
)
# Update the 'source' column in the 'credential' table
op.execute(
"""
UPDATE credential cred
SET source = COALESCE(
(SELECT connector_source
FROM temp_connector_credential temp
WHERE cred.id = temp.credential_id),
'NOT_APPLICABLE'
)
"""
)
# If no exception was raised, alter the column
op.alter_column("credential", "source", nullable=True) # TODO modify
# # ### end Alembic commands ###
def downgrade() -> None:
op.drop_column("credential", "source")
op.drop_column("credential", "name")

View File

@@ -1,25 +0,0 @@
"""hybrid-enum
Revision ID: 5fc1f54cc252
Revises: 1d6ad76d1f37
Create Date: 2024-08-06 15:35:40.278485
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5fc1f54cc252"
down_revision = "1d6ad76d1f37"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("persona", "search_type")
def downgrade() -> None:
op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
op.alter_column("persona", "search_type", nullable=False)

View File

@@ -5,8 +5,11 @@ Revises: fad14119fb92
Create Date: 2024-04-15 01:36:02.952809
"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -50,6 +53,30 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
)
try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
token_budget = settings.get("token_budget", -1)
period_hours = settings.get("period_hours", -1)
if is_enabled and token_budget > 0 and period_hours > 0:
op.execute(
f"INSERT INTO token_rate_limit \
(enabled, token_budget, period_hours, scope) VALUES \
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
)
# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found
pass
def downgrade() -> None:
op.drop_table("token_rate_limit__user_group")

View File

@@ -1,24 +0,0 @@
"""Added model defaults for users
Revision ID: 7477a5f5d728
Revises: 213fd978c6d8
Create Date: 2024-08-04 19:00:04.512634
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "default_model")

View File

@@ -28,9 +28,5 @@ def upgrade() -> None:
def downgrade() -> None:
op.create_unique_constraint(
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
)
op.alter_column(
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
)
# This wasn't really required by the code either, no good reason to make it unique again
pass

View File

@@ -7,8 +7,10 @@ Create Date: 2024-03-22 21:34:27.629444
"""
from alembic import op
import sqlalchemy as sa
from alembic_utils import IndexModelStatus, RecencyBiasSetting, SearchType
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,41 +0,0 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -1,107 +0,0 @@
"""associate index attempts with ccpair
Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new connector_credential_pair_id column
op.add_column(
"index_attempt",
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
)
# Create a foreign key constraint to the connector_credential_pair table
op.create_foreign_key(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
"connector_credential_pair",
["connector_credential_pair_id"],
["id"],
)
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id = (
SELECT id FROM connector_credential_pair ccp
WHERE
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
LIMIT 1
)
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
"""
)
# For good measure
op.execute(
"""
DELETE FROM index_attempt
WHERE connector_credential_pair_id IS NULL
"""
)
# Make the new connector_credential_pair_id column non-nullable
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
# Drop the old connector_id and credential_id columns
op.drop_column("index_attempt", "connector_id")
op.drop_column("index_attempt", "credential_id")
# Update the index to use connector_credential_pair_id
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_credential_pair_id", "time_created"],
)
def downgrade() -> None:
# Add back the old connector_id and credential_id columns
op.add_column(
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
)
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
FROM connector_credential_pair ccp
WHERE ia.connector_credential_pair_id = ccp.id
"""
)
# Make the old connector_id and credential_id columns non-nullable
op.alter_column("index_attempt", "connector_id", nullable=False)
op.alter_column("index_attempt", "credential_id", nullable=False)
# Drop the new connector_credential_pair_id column
op.drop_constraint(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
type_="foreignkey",
)
op.drop_column("index_attempt", "connector_credential_pair_id")
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_id", "credential_id", "time_created"],
)

View File

@@ -7,7 +7,7 @@ Create Date: 2024-03-21 12:05:23.956734
"""
from alembic import op
import sqlalchemy as sa
from alembic_utils import DocumentSource
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"

View File

@@ -1,26 +0,0 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
)
def downgrade() -> None:
op.drop_column("user", "oidc_expiry")

View File

@@ -16,6 +16,7 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -28,9 +29,11 @@ def upgrade() -> None:
),
nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -43,3 +46,4 @@ def downgrade() -> None:
),
nullable=False,
)
# ### end Alembic commands ###

View File

@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM
from alembic_utils import DocumentSource
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "b156fa702355"

View File

@@ -1,24 +0,0 @@
"""add tenant id to user model
Revision ID: b25c363470f3
Revises: 1f60f60c3401
Create Date: 2024-08-29 17:03:20.794120
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b25c363470f3"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("user", sa.Column("tenant_id", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "tenant_id")

View File

@@ -1,57 +0,0 @@
"""Add index_attempt_errors table
Revision ID: c5b692fa265c
Revises: 4a951134c801
Create Date: 2024-08-08 14:06:39.581972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c5b692fa265c"
down_revision = "4a951134c801"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column(
"doc_summaries",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# ### end Alembic commands ###

View File

@@ -19,9 +19,6 @@ depends_on: None = None
def upgrade() -> None:
op.drop_table("deletion_attempt")
# Remove the DeletionStatus enum
op.execute("DROP TYPE IF EXISTS deletionstatus;")
def downgrade() -> None:
op.create_table(

View File

@@ -1,31 +0,0 @@
"""Remove _alt suffix from model_name
Revision ID: d9ec13955951
Revises: da4c21c69164
Create Date: 2024-08-20 16:31:32.955686
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "d9ec13955951"
down_revision = "da4c21c69164"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE embedding_model
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
WHERE model_name LIKE '%__danswer_alt_index'
"""
)
def downgrade() -> None:
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
pass

View File

@@ -1,66 +0,0 @@
"""chosen_assistants changed to jsonb
Revision ID: da4c21c69164
Revises: c5b692fa265c
Create Date: 2024-08-18 19:06:47.291491
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "da4c21c69164"
down_revision = "c5b692fa265c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('SELECT id, chosen_assistants FROM "user"')
)
op.drop_column(
'user',
"chosen_assistants",
)
op.add_column(
'user',
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text('SELECT id, chosen_assistants FROM user')
)
op.drop_column(
'user',
"chosen_assistants",
)
op.add_column(
'user',
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -8,13 +8,20 @@ Create Date: 2024-01-25 17:12:31.813160
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from alembic_utils import IndexModelStatus
from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
)
from danswer.db.models import IndexModelStatus
# revision identifiers, used by Alembic.
revision = "dbaa756c2ccf"
down_revision = "7f726bad5367"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
@@ -33,32 +40,9 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
# Define the old default embedding model directly
old_embedding_model = {
"model_name": "sentence-transformers/all-distilroberta-v1",
"model_dim": 768,
"normalize": True,
"query_prefix": "",
"passage_prefix": "",
"index_name": "OPENSEARCH_INDEX_NAME",
"status": IndexModelStatus.PAST,
}
# Define the new default embedding model directly
new_embedding_model = {
"model_name": "intfloat/e5-small-v2",
"model_dim": 384,
"normalize": False,
"query_prefix": "query: ",
"passage_prefix": "passage: ",
"index_name": "danswer_chunk_intfloat_e5_small_v2",
"status": IndexModelStatus.PRESENT,
}
# Assume the user has not overridden the embedding model
user_overridden_embedding_model = False
# since all index attempts must be associated with an embedding model,
# need to put something in here to avoid nulls. On server startup,
# this value will be overriden
EmbeddingModel = table(
"embedding_model",
column("id", Integer),
@@ -68,23 +52,45 @@ def upgrade() -> None:
column("query_prefix", String),
column("passage_prefix", String),
column("index_name", String),
column("status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)),
column(
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
)
# Insert the old embedding model
# insert an embedding model row that corresponds to the embedding model
# the user selected via env variables before this change. This is needed since
# all index_attempts must be associated with an embedding model, so without this
# we will run into violations of non-null contraints
old_embedding_model = get_old_default_embedding_model()
op.bulk_insert(
EmbeddingModel,
[
old_embedding_model
{
"model_name": old_embedding_model.model_name,
"model_dim": old_embedding_model.model_dim,
"normalize": old_embedding_model.normalize,
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": old_embedding_model.status,
}
],
)
# If the user has not overridden the embedding model, insert the new default model
if not user_overridden_embedding_model:
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[
new_embedding_model
{
"model_name": new_embedding_model.model_name,
"model_dim": new_embedding_model.model_dim,
"normalize": new_embedding_model.normalize,
"query_prefix": new_embedding_model.query_prefix,
"passage_prefix": new_embedding_model.passage_prefix,
"index_name": new_embedding_model.index_name,
"status": IndexModelStatus.FUTURE,
}
],
)
@@ -123,10 +129,11 @@ def upgrade() -> None:
postgresql_where=sa.text("status = 'FUTURE'"),
)
def downgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "embedding_model_id")
op.drop_table("embedding_model")
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
op.execute("DROP TYPE indexmodelstatus;")

View File

@@ -1,58 +0,0 @@
"""Added input prompts
Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -8,7 +8,7 @@ Create Date: 2024-03-14 18:06:08.523106
from alembic import op
import sqlalchemy as sa
from alembic_utils import DocumentSource
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "e50154680a5c"

View File

@@ -1,28 +0,0 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 2d2304e27d8c
Create Date: 2024-08-12 00:11:50.915845
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("overridden_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "overridden_model")

View File

@@ -1,172 +0,0 @@
"""embedding provider by provider type
Revision ID: f17bf3b0d9f1
Revises: 351faebd379d
Create Date: 2024-08-21 13:13:31.120460
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add provider_type column to embedding_provider
op.add_column(
"embedding_provider",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type with existing name values
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
# Make provider_type not nullable
op.alter_column("embedding_provider", "provider_type", nullable=False)
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Drop the existing primary key constraint
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create a new primary key constraint on provider_type
op.create_primary_key(
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
)
# Add provider_type column to embedding_model
op.add_column(
"embedding_model",
sa.Column("provider_type", sa.String(50), nullable=True),
)
# Update provider_type for existing embedding models
op.execute(
"""
UPDATE embedding_model
SET provider_type = (
SELECT provider_type
FROM embedding_provider
WHERE embedding_provider.id = embedding_model.cloud_provider_id
)
"""
)
# Drop the old id column from embedding_provider
op.drop_column("embedding_provider", "id")
# Drop the name column from embedding_provider
op.drop_column("embedding_provider", "name")
# Drop the default_model_id column from embedding_provider
op.drop_column("embedding_provider", "default_model_id")
# Drop the old cloud_provider_id column from embedding_model
op.drop_column("embedding_model", "cloud_provider_id")
# Create the new foreign key constraint
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["provider_type"],
["provider_type"],
)
def downgrade() -> None:
# Drop the foreign key constraint in embedding_model table
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
# Add back the cloud_provider_id column to embedding_model
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
# Assign incrementing IDs to embedding providers
op.execute(
"""
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
)
op.execute(
"""
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
"""
)
# Update cloud_provider_id based on provider_type
op.execute(
"""
UPDATE embedding_model
SET cloud_provider_id = CASE
WHEN provider_type IS NULL THEN NULL
ELSE (
SELECT id
FROM embedding_provider
WHERE embedding_provider.provider_type = embedding_model.provider_type
)
END
"""
)
# Drop the provider_type column from embedding_model
op.drop_column("embedding_model", "provider_type")
# Add back the columns to embedding_provider
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
op.add_column(
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
)
# Drop the existing primary key constraint on provider_type
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
# Create the original primary key constraint on id
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
# Update name with existing provider_type values
op.execute(
"""
UPDATE embedding_provider
SET name = CASE
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
WHEN provider_type = 'COHERE' THEN 'Cohere'
WHEN provider_type = 'GOOGLE' THEN 'Google'
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
ELSE provider_type
END
"""
)
# Drop the provider_type column from embedding_provider
op.drop_column("embedding_provider", "provider_type")
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
# Recreate the foreign key constraint in embedding_model table
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)

View File

@@ -1,99 +0,0 @@
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from os import urandom
import os
from enum import Enum
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
return encoded_key
def encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
return input_str.encode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = urandom(16)
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(input_str.encode()) + padder.finalize()
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted_data
NUM_POSTPROCESSED_RESULTS = 20
class IndexModelStatus(str, Enum):
PAST = "PAST"
PRESENT = "PRESENT"
FUTURE = "FUTURE"
class RecencyBiasSetting(str, Enum):
FAVOR_RECENT = "favor_recent" # 2x decay rate
BASE_DECAY = "base_decay"
NO_DECAY = "no_decay"
# Determine based on query if to use base_decay or favor_recent
AUTO = "auto"
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
SLACK = "slack"
WEB = "web"
GOOGLE_DRIVE = "google_drive"
GMAIL = "gmail"
REQUESTTRACKER = "requesttracker"
GITHUB = "github"
GITLAB = "gitlab"
GURU = "guru"
BOOKSTACK = "bookstack"
CONFLUENCE = "confluence"
SLAB = "slab"
JIRA = "jira"
PRODUCTBOARD = "productboard"
FILE = "file"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
HUBSPOT = "hubspot"
DOCUMENT360 = "document360"
GONG = "gong"
GOOGLE_SITES = "google_sites"
ZENDESK = "zendesk"
LOOPIO = "loopio"
DROPBOX = "dropbox"
SHAREPOINT = "sharepoint"
TEAMS = "teams"
SALESFORCE = "salesforce"
DISCOURSE = "discourse"
AXERO = "axero"
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"

View File

@@ -5,16 +5,19 @@ from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
@@ -25,13 +28,14 @@ def _get_access_for_documents(
def get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
"""Fetches all access information for the given documents."""
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(
document_ids, db_session
document_ids, db_session, cc_pair_to_delete
) # type: ignore

View File

@@ -1,20 +1,21 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
USER_STORE_KEY = "INVITED_USERS"
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
return cast(list, store.load(KV_USER_STORE_KEY))
return cast(list, store.load(USER_STORE_KEY))
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -3,27 +3,29 @@ from typing import Any
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
return UserPreferences(chosen_assistants=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:

View File

@@ -5,20 +5,8 @@ from fastapi_users import schemas
class UserRole(str, Enum):
"""
User roles
- Basic can't perform any admin actions
- Admin can perform all admin actions
- Curator can perform admin actions for
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
"""
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
class UserStatus(str, Enum):
@@ -33,7 +21,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):

View File

@@ -1,19 +1,11 @@
from danswer.configs.app_configs import SECRET_JWT_KEY
from datetime import timedelta
import contextlib
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -46,7 +38,6 @@ from danswer.configs.app_configs import SMTP_PASS
from danswer.configs.app_configs import SMTP_PORT
from danswer.configs.app_configs import SMTP_SERVER
from danswer.configs.app_configs import SMTP_USER
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
@@ -58,52 +49,27 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
logger = setup_logger()
def validate_curator_request(groups: list | None, is_public: bool) -> None:
if is_public:
detail = "Curators cannot create public objects"
logger.error(detail)
raise HTTPException(
status_code=401,
detail=detail,
)
if not groups:
detail = "Curators must specify 1+ groups"
logger.error(detail)
raise HTTPException(
status_code=401,
detail=detail,
)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
if user and user.role == UserRole.ADMIN:
return True
return False
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
raise ValueError(
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
@@ -126,36 +92,10 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
return
if not email:
raise PermissionError("Email must be specified")
email_info = validate_email(email) # can raise EmailNotValidError
for email_whitelist in whitelist:
try:
# normalized emails are now being inserted into the db
# we can remove this normalization on read after some time has passed
email_info_whitelist = validate_email(email_whitelist)
except EmailNotValidError:
continue
# oddly, normalization does not include lowercasing the user part of the
# email address ... which we want to allow
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_domain(email: str) -> None:
@@ -196,95 +136,18 @@ def send_user_verification_email(
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
def verify_sso_token(token: str) -> dict:
try:
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
if datetime.now(timezone.utc) > datetime.fromtimestamp(
payload["exp"], timezone.utc
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
)
return payload
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
async def get_or_create_user(email: str, user_id: str) -> User:
get_async_session_context = contextlib.asynccontextmanager(get_async_session)
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
async with get_async_session_context() as session:
async with get_user_db_context(session) as user_db:
existing_user = await user_db.get_by_email(email)
if existing_user:
return existing_user
new_user = {
"email": email,
"id": uuid.UUID(user_id),
"role": UserRole.BASIC,
"oidc_expiry": None,
"default_model": None,
"chosen_assistants": None,
"hashed_password": "p",
"is_active": True,
"is_superuser": False,
"is_verified": True,
}
created_user: User = await user_db.create(new_user)
return created_user
async def create_user_session(user: User, tenant_id: str) -> str:
# Create a payload user information and tenant_id
payload = {
"sub": str(user.id),
"email": user.email,
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
}
token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
return token
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
async def sso_authenticate(
self,
email: str,
tenant_id: str,
) -> User:
try:
user = await self.get_by_email(email)
except Exception:
# user_create = UserCreate(email=email, password=secrets.token_urlsafe(32))
user_create = UserCreate(
role=UserRole.BASIC, password="password", email=email, is_verified=True
)
user = await self.create(user_create)
# Update user with tenant information if needed
if user.tenant_id != tenant_id:
await self.user_db.update(user, {"tenant_id": tenant_id})
return user
async def create(
self,
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
verify_email_is_invited(user_create.email)
verify_email_in_whitelist(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
@@ -292,7 +155,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
@@ -311,7 +173,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
user = await super().oauth_callback( # type: ignore
return await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
@@ -323,23 +185,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if user.oidc_expiry and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, update_dict={"oidc_expiry": None})
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has registered.")
logger.info(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -349,14 +198,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
verify_email_domain(user.email)
logger.notice(
logger.info(
f"Verification requested for user {user.id}. Verification token: {token}"
)
@@ -378,10 +227,9 @@ cookie_transport = CookieTransport(
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
strategy = DatabaseStrategy(
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
@@ -479,12 +327,6 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
return user
@@ -494,28 +336,6 @@ async def current_user(
return await double_check_user(user)
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
if DISABLE_AUTH:
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
return user
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
if DISABLE_AUTH:
return None
@@ -523,12 +343,6 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
detail="Access denied. User is not an admin.",
)
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Danswer MIT
return []

View File

@@ -1,19 +1,10 @@
from danswer.configs.app_configs import MULTI_TENANT
from danswer.background.update import get_all_tenant_ids
import json
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery # type: ignore
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
@@ -23,8 +14,6 @@ from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -49,9 +38,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
connection_string = build_connection_string(db_api=SYNC_DB_API)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@@ -70,12 +57,11 @@ _SYNC_BATCH_SIZE = 100
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
tenant_id: str | None
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine(schema=tenant_id)
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
@@ -105,7 +91,6 @@ def cleanup_connector_credential_pair_task(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
@@ -114,8 +99,8 @@ def cleanup_connector_credential_pair_task(
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
@@ -172,7 +157,6 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str |
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
@@ -183,7 +167,7 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str |
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
@@ -216,7 +200,7 @@ def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
try:
cursor = None
while True:
@@ -267,10 +251,10 @@ def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task(tenant_id: str | None) -> None:
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -283,151 +267,15 @@ def check_for_document_sets_sync_task(tenant_id: str | None) -> None:
)
@celery_app.task(
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# check if any document sets are not synced
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
logger.notice(f"Deleting the {cc_pair.name} connector credential pair")
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
tenant_id=tenant_id
),
)
@celery_app.task(
name="kombu_message_cleanup_task",
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
logger.info(f"Deleted {ctx['deleted']} orphaned messages from kombu_message.")
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
else:
task_name = payload["headers"]["task"]
logger.warning(
f"Message found for task older than {ctx['cleanup_age']} days. "
f"id={task_id} name={task_name}"
)
ctx["last_processed_id"] = msg[0]
return True
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task(tenant_id: str | None) -> None:
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
@@ -442,7 +290,6 @@ def check_for_prune_task(tenant_id: str | None) -> None:
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
tenant_id=tenant_id,
)
)
@@ -450,33 +297,17 @@ def check_for_prune_task(tenant_id: str | None) -> None:
#####
# Celery Beat (Periodic Tasks) Settings
#####
def schedule_tenant_tasks() -> None:
tenants = get_all_tenant_ids()
for tenant_id in tenants:
# Schedule tasks specific to each tenant
celery_app.conf.beat_schedule[f"check-for-document-set-sync-{tenant_id}"] = {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
"args": (tenant_id,),
}
celery_app.conf.beat_schedule[f"check-for-cc-pair-deletion-{tenant_id}"] = {
"task": "check_for_cc_pair_deletion_task",
"schedule": timedelta(seconds=5),
"args": (tenant_id,),
}
celery_app.conf.beat_schedule[f"check-for-prune-{tenant_id}"] = {
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
"args": (tenant_id,),
}
# Schedule tasks that are not tenant-specific
celery_app.conf.beat_schedule["kombu-message-cleanup"] = {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
"args": (tenant_id,),
}
schedule_tenant_tasks()
},
}
)

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -16,14 +16,10 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
@@ -33,54 +29,24 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
def get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
) -> DeletionAttemptSnapshot | None:
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
if not task_state:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_task.status,
status=task_state.status,
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = _get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
@@ -92,7 +58,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now.")
logger.info(f"Document set {document_set.id} syncing now!")
return True
@@ -114,7 +80,7 @@ def should_prune_cc_pair(
return True
return False
if not ALLOW_SIMULTANEOUS_PRUNING:
if PREVENT_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -123,9 +89,11 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:

View File

@@ -10,6 +10,8 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
import time
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
@@ -22,8 +24,10 @@ from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.document_set import (
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
@@ -31,11 +35,6 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
from danswer.configs.app_configs import DEFAULT_SCHEMA
logger = setup_logger()
@@ -47,13 +46,12 @@ def delete_connector_credential_pair_batch(
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
tenant_id: str | None
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
@@ -80,37 +78,25 @@ def delete_connector_credential_pair_batch(
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
@@ -122,11 +108,52 @@ def delete_connector_credential_pair_batch(
db_session.commit()
def cleanup_synced_entities(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> None:
"""Updates the document sets associated with the connector / credential pair,
then relies on the document set sync script to kick off Celery jobs which will
sync these updates to Vespa.
Waits until the document sets are synced before returning."""
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
document_sets_ids_to_sync = list(
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
)
db_session.commit()
# wait till all document sets are synced before continuing
while True:
all_synced = True
document_sets = get_document_sets_by_ids(
db_session=db_session, document_set_ids=document_sets_ids_to_sync
)
for document_set in document_sets:
if not document_set.is_up_to_date:
all_synced = False
if all_synced:
break
# wait for 30 seconds before checking again
db_session.commit() # end transaction
logger.info(
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
)
time.sleep(30)
logger.info(
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
)
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
tenant_id: str | None
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
@@ -138,7 +165,6 @@ def delete_connector_credential_pair(
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
@@ -148,37 +174,20 @@ def delete_connector_credential_pair(
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tenant_id=tenant_id,
)
num_docs_deleted += len(documents)
# Clean up document sets / access information from Postgres
# and sync these updates to Vespa
# TODO: add user group cleanup with `fetch_versioned_implementation`
cleanup_synced_entities(cc_pair, db_session)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
@@ -190,11 +199,11 @@ def delete_connector_credential_pair(
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.info("Found no credentials left for connector, deleting connector")
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.notice(
logger.info(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)

View File

@@ -41,12 +41,6 @@ def _initializer(
return func(*args, **kwargs)
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, args, kwargs)
@dataclass
class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
@@ -119,7 +113,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_run_in_process, args=(func, args), daemon=True)
process = Process(target=_initializer(func=func, args=args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -1,4 +1,3 @@
import time
import traceback
from datetime import datetime
@@ -6,22 +5,22 @@ from datetime import timedelta
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.connector_runner import ConnectorRunner
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -36,16 +35,13 @@ from danswer.utils.variable_functionality import global_version
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None
) -> ConnectorRunner:
) -> GenerateDocumentsOutput:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -53,38 +49,48 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
task = attempt.connector_credential_pair.connector.input_type
task = attempt.connector.input_type
try:
runnable_connector = instantiate_connector(
attempt.connector_credential_pair.connector.source,
attempt.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
attempt.connector.connector_specific_config,
attempt.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
disable_connector(attempt.connector.id, db_session)
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
)
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -92,65 +98,48 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
embedding_model = DefaultIndexingEmbedder(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
or (db_embedding_model.status == IndexModelStatus.FUTURE),
db_session=db_session,
tenant_id=tenant_id,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
db_connector = index_attempt.connector
db_credential = index_attempt.credential
last_successful_index_time = (
db_connector.indexing_start.timestamp()
if index_attempt.from_beginning and db_connector.indexing_start is not None
else (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
@@ -169,32 +158,23 @@ def _run_indexing(
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
connector_runner = _get_connector_runner(
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_connector)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
db_connector.disabled
and db_embedding_model.status != IndexModelStatus.FUTURE
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
@@ -202,32 +182,19 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(f"Index Attempt was canceled, status is {index_attempt.status}")
raise RuntimeError("Index Attempt was canceled")
batch_description = []
for doc in doc_batch:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
batch_num += 1
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
@@ -249,17 +216,6 @@ def _run_indexing(
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
@@ -270,7 +226,7 @@ def _run_indexing(
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
logger.info(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
@@ -282,7 +238,7 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or db_connector.disabled
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
@@ -294,66 +250,17 @@ def _run_indexing(
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -362,6 +269,13 @@ def _run_indexing(
run_dt=run_end_dt,
)
logger.info(
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
@@ -385,34 +299,42 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress(attempt, db_session)
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()
return attempt
def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None:
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Indexing starting for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
_run_indexing(db_session, attempt, tenant_id)
_run_indexing(db_session, attempt)
logger.info(
f"Indexing finished for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}")
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -1,77 +0,0 @@
import tracemalloc
from danswer.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class DanswerTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -14,11 +14,8 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None = None) -> str:
task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
if tenant_id is not None:
task_name += f"_{tenant_id}"
return task_name
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:

View File

@@ -8,7 +8,6 @@ from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from sqlalchemy import text
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
@@ -17,37 +16,30 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Connector
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import MULTI_TENANT
from sqlalchemy.exc import ProgrammingError
logger = setup_logger()
@@ -61,64 +53,41 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
connector: Connector,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
return False
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if not cc_pair.status.is_active() or connector.id == 0:
# If the connector is disabled, don't index
# NOTE: during an embedding model switch over, we ignore this
# and index the disabled connectors as well (which is why this if
# statement is below the first condition above)
if connector.disabled:
return False
if not last_index:
return True
if connector.refresh_freq is None:
return False
if not last_index:
return True
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
# Only one scheduled job per connector at a time
# Can schedule another one if the current one is already running however
# Because the currently running one will not be until the latest time
# Note, this last index is for the given embedding model
if last_index.status == IndexingStatus.NOT_STARTED:
return False
current_db_time = get_db_current_time(db_session)
@@ -126,14 +95,24 @@ def _should_create_new_indexing(
return time_since_index.total_seconds() >= connector.refresh_freq
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
@@ -145,15 +124,14 @@ def _mark_run_failed(
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None:
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
ongoing: set[tuple[int | None, int]] = set()
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@@ -166,66 +144,62 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id
continue
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.search_settings_id,
attempt.connector_id,
attempt.credential_id,
attempt.embedding_model_id,
)
)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
embedding_models = [get_current_db_embedding_model(db_session)]
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
all_connectors = fetch_connectors(db_session)
for connector in all_connectors:
for association in connector.credentials:
for model in embedding_models:
credential = association.credential
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for search_settings_instance in search_settings:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, search_settings_instance.id) in ongoing:
continue
# Check if there is an ongoing indexing attempt for this connector + credential pair
if (connector.id, credential.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
last_attempt = get_last_attempt(
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(
cc_pair.id, search_settings_instance.id, db_session
)
create_index_attempt(
connector.id, credential.id, model.id, db_session
)
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
tenant_id: str | None,
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if job.status == "error":
logger.error(job.exception())
@@ -251,42 +225,38 @@ def cleanup_indexing_jobs(
)
# clean up in-progress jobs that were never completed
try:
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
except ProgrammingError as _:
logger.debug(f"No Connector Table exists for: {tenant_id}")
pass
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
return existing_jobs_copy
@@ -294,37 +264,31 @@ def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
tenant_id: str | None,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine(schema=tenant_id)
engine = get_sqlalchemy_engine()
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.search_settings)
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
for attempt, search_settings in new_indexing_attempts:
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector_credential_pair.connector is None:
if attempt.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
@@ -333,7 +297,7 @@ def kickoff_indexing_jobs(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.connector_credential_pair.credential is None:
if attempt.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
@@ -347,78 +311,62 @@ def kickoff_indexing_jobs(
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
global_version.get_is_ee_version(),
pure=False,
)
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
global_version.get_is_ee_version(),
pure=False,
)
if run:
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
secondary_str = "(secondary index) " if use_secondary_index else ""
logger.info(
f"Indexing dispatched{secondary_str}: "
f"attempt_id={attempt.id} "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
f"Kicked off {secondary_str}"
f"indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with Session(get_sqlalchemy_engine(schema='public')) as session:
result = session.execute(text("""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')
"""))
tenant_ids = [row[0] for row in result]
valid_tenants = [tenant for tenant in tenant_ids if tenant is None or not tenant.startswith('pg_')]
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
return valid_tenants
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warning about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_secondary_workers,
n_workers=num_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
@@ -428,77 +376,43 @@ def update_loop(
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
existing_jobs: dict[int, Future | SimpleJob] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
logger.info(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
logger.debug(
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.info(
"Found existing indexing jobs: "
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
try:
tenants = get_all_tenant_ids()
for tenant_id in tenants:
try:
logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}")
engine = get_sqlalchemy_engine(schema=tenant_id)
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
if not MULTI_TENANT:
search_settings = get_current_search_settings(db_session)
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
logger.notice("First inference complete.")
tenant_jobs = existing_jobs.get(tenant_id, {})
tenant_jobs = cleanup_indexing_jobs(
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
create_indexing_jobs(
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
tenant_jobs = kickoff_indexing_jobs(
existing_jobs=tenant_jobs,
client=client_primary,
secondary_client=client_secondary,
tenant_id=tenant_id,
)
existing_jobs[tenant_id] = tenant_jobs
except Exception as e:
logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}")
with Session(get_sqlalchemy_engine()) as db_session:
check_index_swap(db_session)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
client=client_primary,
secondary_client=client_secondary,
)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
logger.notice("Starting indexing service")
logger.info("Starting Indexing Loop")
update_loop()

View File

@@ -35,19 +35,14 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
stop_at_message_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,
db_session=db_session,
skip_permission_check=True,
prefetch_tool_calls=prefetch_tool_calls,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}
@@ -63,12 +58,7 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
# Break if at the end of the chain
# or have reached the `final_id` of the submitted message
if not child_msg or (
stop_at_message_id and current_message.id == stop_at_message_id
):
if not child_msg:
break
current_message = id_to_msg.get(child_msg)

View File

@@ -1,24 +0,0 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -1,172 +1,107 @@
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
db_session: Session,
prompts_yaml: str = PROMPTS_YAML
) -> None:
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
upsert_persona(
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
default_persona=True,
is_public=True,
db_session=db_session,
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(
db_session: Session,
input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)

View File

@@ -9,7 +9,6 @@ from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
class LlmDoc(BaseModel):
@@ -35,12 +34,11 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
@@ -48,22 +46,15 @@ class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceAnalysis(BaseModel):
relevant: bool
class RelevanceChunk(BaseModel):
# TODO make this document level. Also slight misnomer here as this is actually
# done at the section level currently rather than the chunk
relevant: bool | None = None
content: str | None = None
class SectionRelevancePiece(RelevanceAnalysis):
"""LLM analysis mapped to an Inference Section"""
document_id: str
chunk_id: int # ID of the center chunk for a given inference section
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class DanswerAnswerPiece(BaseModel):
@@ -78,14 +69,8 @@ class CitationInfo(BaseModel):
document_id: str
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
class DanswerQuote(BaseModel):
@@ -132,7 +117,7 @@ class ImageGenerationDisplay(BaseModel):
class CustomToolResponse(BaseModel):
response: ToolResultType
response: dict
tool_name: str

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Knowledge"
name: "Danswer"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
@@ -17,7 +17,7 @@ personas:
num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
llm_relevance_filter: true
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
llm_filter_extraction: true
@@ -37,15 +37,12 @@ personas:
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 1
is_visible: true
- id: 1
name: "General"
name: "GPT"
description: >
Assistant with no access to documents. Chat with just the Large Language Model.
Assistant with no access to documents. Chat with just the Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -53,10 +50,7 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 0
is_visible: true
- id: 2
name: "Paraphrase"
@@ -69,25 +63,3 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 45519
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
- id: 3
name: "Art"
description: >
Assistant for generating images based on descriptions.
prompts:
- "ImageGeneration"
num_chunks: 0
llm_relevance_filter: false
llm_filter_extraction: false
recency_bias: "no_decay"
document_sets: []
icon_shape: 234124
icon_color: "#9B59B6"
image_generation: true
display_priority: 3
is_visible: true

View File

@@ -1,4 +1,3 @@
import traceback
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
@@ -12,7 +11,6 @@ from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@@ -29,15 +27,15 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -53,9 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -64,7 +60,6 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@@ -183,7 +178,7 @@ def _handle_internet_search_tool_response_summary(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.SEMANTIC,
predicted_search=SearchType.HYBRID,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
@@ -192,46 +187,37 @@ def _handle_internet_search_tool_response_summary(
)
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
if new_msg_req.file_descriptors:
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
return None
should_force_search = any(
[
if (
new_msg_req.query_override
or (
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
ChatPacket = (
@@ -243,7 +229,6 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -259,21 +244,17 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
"""
try:
user_id = user.id if user is not None else None
@@ -293,10 +274,7 @@ def stream_chat_message_objects(
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
alternate_assistant_id, user=user, db_session=db_session
)
else:
persona = chat_session.persona
@@ -313,28 +291,20 @@ def stream_chat_message_objects(
try:
llm, fast_llm = get_llms_for_persona(
persona=persona,
db_session=db_session,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_provider = llm.config.model_provider
llm_model_name = llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
search_settings = get_current_search_settings(db_session)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=search_settings.index_name, secondary_index_name=None
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
@@ -352,15 +322,7 @@ def stream_chat_message_objects(
parent_message = root_message
user_message = None
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
stop_at_message_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
elif not use_existing_user_message:
if not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
@@ -399,14 +361,6 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -466,23 +420,9 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
overridden_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
@@ -490,7 +430,6 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
# rephrased_query=,
# token_count=,
@@ -537,9 +476,6 @@ def stream_chat_message_objects(
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -608,16 +544,13 @@ def stream_chat_message_objects(
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
llm.config.model_provider, llm.config.model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
@@ -632,7 +565,6 @@ def stream_chat_message_objects(
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
db_session=db_session,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
@@ -644,7 +576,11 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
)
reference_db_search_docs = None
@@ -652,7 +588,6 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -671,28 +606,18 @@ def stream_chat_message_objects(
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
chunk_indices = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -728,18 +653,20 @@ def stream_chat_message_objects(
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
logger.exception("Failed to process chat message")
# Don't leak the API key
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
error_msg = (
f"LLM failed to respond. Invalid API "
f"key error from '{llm.config.model_provider}'."
)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
stack_trace = stack_trace.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
yield StreamingError(error=error_msg)
# Cancel the transaction so that no messages are saved
db_session.rollback()
return
@@ -759,7 +686,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -780,8 +706,6 @@ def stream_chat_message_objects(
if tool_result
else [],
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
@@ -790,8 +714,7 @@ def stream_chat_message_objects(
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
logger.exception(e)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@@ -801,19 +724,16 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
)
for obj in objects:
yield get_json_line(obj.dict())

View File

@@ -30,23 +30,7 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images based on user prompts!"
system: >
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
- name: "OnlyLLM"
description: "Chat directly with the LLM!"

View File

@@ -1,4 +1,4 @@
from typing_extensions import TypedDict # noreorder
from typing import TypedDict
from pydantic import BaseModel

View File

@@ -37,11 +37,9 @@ DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY") or "JWT_SECRET_KEY"
#####
# Auth Configs
#####
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
@@ -95,14 +93,6 @@ SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Danswer will listen to the `expires_at` returned by the identity
# provider (e.g. Okta, Google, etc.) and force the user to re-authenticate
# after this time has elapsed. Disabled since by default many auth providers
# have very short expiry times (e.g. 1 hour) which provide a poor user experience
TRACK_EXTERNAL_IDP_EXPIRY = (
os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true"
)
#####
# DB Configs
@@ -136,20 +126,9 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
# recycle timeout in seconds
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20 # 20 minutes
try:
POSTGRES_POOL_RECYCLE = int(
os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
)
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
#####
# Connector Configs
@@ -202,8 +181,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
]
# Avoid to get archived pages
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES", "").lower() == "true"
CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES = (
os.environ.get("CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES", "").lower() == "true"
)
# Save pages labels as Danswer metadata tags
@@ -212,16 +191,6 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -243,11 +212,10 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
PRUNING_DISABLED = -1
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -280,39 +248,18 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
@@ -340,10 +287,6 @@ LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
LOG_POSTGRES_CONN_COUNTS = (
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
@@ -368,21 +311,3 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
###
# CLOUD CONFIGS
###
STRIPE_PRICE = os.environ.get("STRIPE_PRICE", "price_1PsYoPHlhTYqRZib2t5ydpq5")
STRIPE_WEBHOOK_SECRET = os.environ.get(
"STRIPE_WEBHOOK_SECRET",
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
)
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "your_shared_secret_key")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "your_control_plane_api_key")
MULTI_TENANT = os.environ.get("MULTI_TENANT", "false").lower() == "true"

View File

@@ -3,13 +3,12 @@ import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_POSTPROCESSED_RESULTS = 20
NUM_RERANKED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@@ -31,9 +30,13 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
@@ -44,19 +47,22 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
if os.environ.get("EDIT_KEYWORD_QUERY"):
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is most of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
@@ -69,16 +75,16 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_DOC_RELEVANCE = (
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
)
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
@@ -88,4 +94,3 @@ HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

View File

@@ -1,7 +1,26 @@
from enum import auto
from enum import Enum
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"
CONTENT = "content"
SOURCE_TYPE = "source_type"
SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SEMANTIC_IDENTIFIER = "semantic_identifier"
TITLE = "title"
SKIP_TITLE_EMBEDDING = "skip_title"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ALLOWED_USERS = "allowed_users"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
# are still useful as a search result but not for QA.
@@ -9,11 +28,23 @@ IGNORE_FOR_QA = "ignore_for_qa"
# NOTE: deprecated, only used for porting key from old system
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
PUBLIC_DOCUMENT_SET = "__PUBLIC"
QUOTE = "quote"
BOOST = "boost"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
RECENCY_BIAS = "recency_bias"
HIDDEN = "hidden"
SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
@@ -29,37 +60,12 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -103,10 +109,6 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
class NotificationType(str, Enum):
REINDEX = "reindex"
class BlobType(str, Enum):
R2 = "r2"
S3 = "s3"
@@ -162,7 +164,3 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()

View File

@@ -12,15 +12,13 @@ import os
# The useable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
# If the below is changed, Vespa deployment must also be changed
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
# Model should be chosen with 512 context size, ideally don't change this
# If multipass_indexing is enabled, the max context size would be set to
# DOC_EMBEDDING_CONTEXT_SIZE * LARGE_CHUNK_RATIO
DOC_EMBEDDING_CONTEXT_SIZE = 512
NORMALIZE_EMBEDDINGS = (
os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
@@ -36,16 +34,17 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
#####
# Generative AI Model Configs
@@ -82,9 +81,6 @@ GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily

View File

@@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
Raises ValueError for unsupported bucket types.
"""
logger.debug(
logger.info(
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
)
@@ -169,7 +169,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
end: datetime,
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
raise ConnectorMissingCredentialError("Blog storage")
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
@@ -220,7 +220,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
logger.debug("Loading blob objects")
logger.info("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
@@ -230,7 +230,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
raise ConnectorMissingCredentialError("Blog storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

View File

@@ -13,11 +13,7 @@ import bs4
from atlassian import Confluence # type:ignore
from requests import HTTPError
from danswer.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -45,14 +41,6 @@ logger = setup_logger()
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR = (
"User not permitted to view attachments on content"
)
NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
"No parent or not permitted to view content with id"
)
def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
"""Sample
URL w/ page: https://danswer.atlassian.net/wiki/spaces/1234abcd/pages/5678efgh/overview
@@ -211,55 +199,34 @@ def _comment_dfs(
comments_str += "\nComment:\n" + parse_html_page(
comment_html, confluence_client
)
try:
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
start=None,
limit=None,
expand="body.storage.value",
)
comments_str = _comment_dfs(
comments_str, child_comment_pages, confluence_client
)
except HTTPError as e:
# not the cleanest, but I'm not aware of a nicer way to check the error
if NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR not in str(e):
raise
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
start=None,
limit=None,
expand="body.storage.value",
)
comments_str = _comment_dfs(
comments_str, child_comment_pages, confluence_client
)
return comments_str
def _datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime.fromisoformat(datetime_string)
if datetime_object.tzinfo is None:
# If no timezone info, assume it is UTC
datetime_object = datetime_object.replace(tzinfo=timezone.utc)
else:
# If not in UTC, translate it
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
class RecursiveIndexer:
def __init__(
self,
batch_size: int,
confluence_client: Confluence,
index_recursively: bool,
index_origin: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.confluence_client = confluence_client
self.index_recursively = index_recursively
self.index_origin = index_origin
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
@@ -315,11 +282,12 @@ class RecursiveIndexer:
current_level_pages = next_level_pages
next_level_pages = []
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
@@ -372,7 +340,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_recursively: bool = True,
index_origin: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -384,7 +352,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_recursively = index_recursively
self.index_origin = index_origin
(
self.wiki_base,
self.space,
@@ -401,7 +369,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -432,7 +400,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start=start_ind,
limit=batch_size,
status=(
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
@@ -453,9 +423,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start=start_ind + i,
limit=1,
status=(
None
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
else "current"
"current"
if CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
else None
),
expand="body.storage.value,version",
)
@@ -483,13 +453,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_recursively=self.index_recursively,
index_origin=self.index_origin,
)
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
return self.recursive_indexer.get_pages(start_ind, batch_size)
pages: list[dict[str, Any]] = []
@@ -562,249 +529,134 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.exception("Ran into exception when fetching labels from Confluence")
return []
@classmethod
def _attachment_to_download_link(
cls, confluence_client: Confluence, attachment: dict[str, Any]
) -> str:
return confluence_client.url + attachment["_links"]["download"]
@classmethod
def _attachment_to_content(
cls,
confluence_client: Confluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
return None
download_link = cls._attachment_to_download_link(confluence_client, attachment)
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list = []
) -> str:
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=0, limit=500, expand=expand
page_id, start=0, limit=500
)
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
unused_attachments.append(attachment)
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
continue
attachment_content = self._attachment_to_content(
confluence_client, attachment
)
if attachment_content:
files_attachment_content.append(attachment_content)
if attachment["title"] not in files_in_used:
continue
download_link = confluence_client.url + attachment["_links"]["download"]
response = confluence_client._session.get(download_link)
if response.status_code == 200:
extract = extract_file_text(
attachment["title"], io.BytesIO(response.content), False
)
files_attachment_content.append(extract)
except Exception as e:
if isinstance(
e, HTTPError
) and NO_PERMISSIONS_TO_VIEW_ATTACHMENTS_ERROR_STR in str(e):
logger.warning(
f"User does not have access to attachments on page '{page_id}'"
)
return "", []
if not self.continue_on_failure:
raise e
logger.exception(
f"Ran into exception when fetching attachments from Confluence: {e}"
)
return "\n".join(files_attachment_content), unused_attachments
return "\n".join(files_attachment_content)
def _get_doc_batch(
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
) -> tuple[list[Document], list[dict[str, Any]], int]:
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
unused_attachments: list[dict[str, Any]] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
last_modified_str = page["version"]["when"]
author = cast(str | None, page["version"].get("by", {}).get("email"))
last_modified = datetime.fromisoformat(last_modified_str)
if time_filter and not time_filter(last_modified):
continue
if last_modified.tzinfo is None:
# If no timezone info, assume it is UTC
last_modified = last_modified.replace(tzinfo=timezone.utc)
else:
# If not in UTC, translate it
last_modified = last_modified.astimezone(timezone.utc)
page_id = page["id"]
if time_filter is None or time_filter(last_modified):
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
# check disallowed labels
if self.labels_to_skip:
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
# check disallowed labels
if self.labels_to_skip:
label_intersection = self.labels_to_skip.intersection(page_labels)
if label_intersection:
logger.info(
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
page["body"]
.get("storage", page["body"].get("view", {}))
.get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
page_html = (
page["body"].get("storage", page["body"].get("view", {})).get("value")
)
page_url = self.wiki_base + page["_links"]["webui"]
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_used = get_used_attachments(page_html, self.confluence_client)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
unused_attachments.extend(unused_page_attachments)
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=page_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
files_in_used = get_used_attachments(page_html, self.confluence_client)
attachment_text = self._fetch_attachments(
self.confluence_client, page_id, files_in_used
)
)
return (
doc_batch,
unused_attachments,
len(batch),
)
page_text += attachment_text
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": self.space
}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
def _get_attachment_batch(
self,
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
end_ind = min(start_ind + self.batch_size, len(attachments))
for attachment in attachments[start_ind:end_ind]:
last_updated = _datetime_from_string(
attachment["history"]["lastUpdated"]["when"]
)
if time_filter and not time_filter(last_updated):
continue
attachment_url = self._attachment_to_download_link(
self.confluence_client, attachment
)
attachment_content = self._attachment_to_content(
self.confluence_client, attachment
)
if attachment_content is None:
continue
creator_email = attachment["history"]["createdBy"].get("email")
comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
for label in attachment["metadata"]["labels"]["results"]:
attachment_labels.append(label["name"])
doc_metadata["labels"] = attachment_labels
doc_batch.append(
Document(
id=attachment_url,
sections=[Section(link=attachment_url, text=attachment_content)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment["title"],
doc_updated_at=last_updated,
primary_owners=(
[BasicExpertInfo(email=creator_email)]
if creator_email
else None
),
metadata=doc_metadata,
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=page_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author)] if author else None
),
metadata=doc_metadata,
)
)
)
return doc_batch, end_ind - start_ind
return doc_batch, len(batch)
def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind
)
unused_attachments.extend(unused_attachments_batch)
doc_batch, num_pages = self._get_doc_batch(start_ind)
start_ind += num_pages
if doc_batch:
yield doc_batch
@@ -812,23 +664,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -837,11 +675,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
doc_batch, num_pages = self._get_doc_batch(
start_ind, time_filter=lambda t: start_time <= t <= end_time
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
if doc_batch:
yield doc_batch
@@ -849,20 +685,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
if __name__ == "__main__":
connector = ConfluenceConnector(os.environ["CONFLUENCE_TEST_SPACE_URL"])

View File

@@ -23,12 +23,11 @@ class ConfluenceRateLimitError(Exception):
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
max_retries = 10
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(max_retries):
for attempt in range(10):
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
@@ -56,14 +55,5 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
else:
# re-raise, let caller handle
raise
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
logger.warning(f"Confluence Internal Error, retrying... {e}")
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
if attempt == max_retries - 1:
raise e
return cast(F, wrapped_call)

View File

@@ -1,70 +0,0 @@
import sys
from datetime import datetime
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.utils.logger import setup_logger
logger = setup_logger()
TimeRange = tuple[datetime, datetime]
class ConnectorRunner:
def __init__(
self,
connector: BaseConnector,
time_range: TimeRange | None = None,
fail_loudly: bool = False,
):
self.connector = connector
if isinstance(self.connector, PollConnector):
if time_range is None:
raise ValueError("time_range is required for PollConnector")
self.doc_batch_generator = self.connector.poll_source(
time_range[0].timestamp(), time_range[1].timestamp()
)
elif isinstance(self.connector, LoadConnector):
if time_range and fail_loudly:
raise ValueError(
"time_range specified, but passed in connector is not a PollConnector"
)
self.doc_batch_generator = self.connector.load_from_state()
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
def run(self) -> GenerateDocumentsOutput:
"""Adds additional exception logging to the connector."""
try:
yield from self.doc_batch_generator
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
# Traverse the traceback to find the last frame where the exception was raised
tb = exc_traceback
if tb is None:
logger.error("No traceback found for exception")
raise
while tb.tb_next:
tb = tb.tb_next # Move to the next frame in the traceback
# Get the local variables from the frame where the exception occurred
local_vars = tb.tb_frame.f_locals
local_vars_str = "\n".join(
f"{key}: {value}" for key, value in local_vars.items()
)
logger.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str}"
)
raise

View File

@@ -56,7 +56,7 @@ class _RateLimitDecorator:
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logger.notice(
logger.info(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)

View File

@@ -56,16 +56,6 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -127,10 +117,8 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -159,18 +147,14 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
doc_batch.append(
Document(

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=50, period=60)
@rate_limit_builder(max_calls=100, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

View File

@@ -159,12 +159,10 @@ class LocalFileConnector(LoadConnector):
self,
file_locations: list[Path | str],
batch_size: int = INDEX_BATCH_SIZE,
tenant_id: str | None = None
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
self.tenant_id = tenant_id
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -172,7 +170,7 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine(schema=self.tenant_id)) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(

View File

@@ -38,7 +38,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
logger.info(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
time.sleep(sleep_time.seconds)

View File

@@ -11,17 +11,16 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import CRED_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import GMAIL_CRED_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.gmail.constants import GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
@@ -50,7 +49,7 @@ def get_gmail_creds_for_authorized_user(
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Gmail tokens.")
logger.info("Refreshed Gmail tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh gmail access token due to: {e}")
@@ -72,7 +71,7 @@ def get_gmail_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
@@ -80,7 +79,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -92,14 +91,12 @@ def get_gmail_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -111,9 +108,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
return str(auth_url)
@@ -125,7 +120,7 @@ def update_gmail_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
app_credentials.dict(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)
@@ -151,29 +146,28 @@ def build_service_account_creds(
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
)
def delete_google_app_gmail_cred() -> None:
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
get_dynamic_config_store().delete(GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
creds_str = str(get_dynamic_config_store().load(GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
@@ -181,19 +175,19 @@ def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_dynamic_config_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -1,4 +1,7 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
CRED_KEY = "credential_id_{}"
GMAIL_CRED_KEY = "gmail_app_credential"
GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

View File

@@ -81,10 +81,10 @@ class GongConnector(LoadConnector, PollConnector):
for workspace in workspace_list:
if workspace:
logger.info(f"Updating Gong workspace: {workspace}")
logger.info(f"Updating workspace: {workspace}")
workspace_id = workspace_map.get(workspace)
if not workspace_id:
logger.error(f"Invalid Gong workspace: {workspace}")
logger.error(f"Invalid workspace: {workspace}")
if not self.continue_on_fail:
raise ValueError(f"Invalid workspace: {workspace}")
continue

View File

@@ -267,7 +267,7 @@ def get_all_files_batched(
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.debug(
pre_batch_yield=lambda batch_files: logger.info(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
@@ -306,29 +306,24 @@ def get_all_files_batched(
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = "text/plain"
if mime_type == GDriveMimeType.SPREADSHEET.value:
export_mime_type = "text/csv"
elif mime_type == GDriveMimeType.PPT.value:
export_mime_type = "text/plain"
response = (
if mime_type == GDriveMimeType.DOC.value:
return (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.export(fileId=file["id"], mimeType="text/plain")
.execute()
.decode("utf-8")
)
elif mime_type == GDriveMimeType.SPREADSHEET.value:
return (
service.files()
.export(fileId=file["id"], mimeType="text/csv")
.execute()
.decode("utf-8")
)
return response.decode("utf-8")
elif mime_type == GDriveMimeType.WORD_DOC.value:
response = service.files().get_media(fileId=file["id"]).execute()
return docx_to_text(file=io.BytesIO(response))
@@ -338,6 +333,9 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
elif mime_type == GDriveMimeType.POWERPOINT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PPT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT

View File

@@ -11,10 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import CRED_KEY
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
@@ -22,6 +19,8 @@ from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_CRED_KEY
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
@@ -50,7 +49,7 @@ def get_google_drive_creds_for_authorized_user(
try:
creds.refresh(Request())
if creds.valid:
logger.notice("Refreshed Google Drive tokens.")
logger.info("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
@@ -72,7 +71,7 @@ def get_google_drive_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
@@ -80,7 +79,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -92,9 +91,7 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
return str(auth_url)
@@ -106,7 +103,7 @@ def update_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
app_credentials.dict(),
scopes=SCOPES,
redirect_uri=_build_frontend_google_drive_redirect(),
)
@@ -121,7 +118,6 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
@@ -135,37 +131,34 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
def delete_google_app_cred() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
get_dynamic_config_store().delete(GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -1,6 +1,9 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
CRED_KEY = "credential_id_{}"
GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",

View File

@@ -103,10 +103,6 @@ class GuruConnector(LoadConnector, PollConnector):
# In UI it's called Folders
metadata_dict["folders"] = boards
collection = card.get("collection", {})
if collection:
metadata_dict["collection_name"] = collection.get("name", "")
owner = card.get("owner", {})
author = None
if owner:

View File

@@ -86,6 +86,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: The categories to include in the index.
pages: The pages to include in the index.
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
connector_name: The name of the connector.
language_code: The language code of the wiki.
batch_size: The batch size for loading documents.
@@ -103,6 +104,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -116,8 +118,10 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.connector_name = connector_name
connector_name = "".join(ch for ch in connector_name if ch.isalnum())
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.family = family_class_dispatch(hostname, connector_name)()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
@@ -206,6 +210,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
HOSTNAME = "fallout.fandom.com"
test_connector = MediaWikiConnector(
connector_name="Fallout",
hostname=HOSTNAME,
categories=["Fallout:_New_Vegas_factions"],
pages=["Fallout: New Vegas"],

View File

@@ -114,9 +114,7 @@ class DocumentBase(BaseModel):
title: str | None = None
from_ingestion_api: bool = False
def get_title_for_document_index(
self,
) -> str | None:
def get_title_for_document_index(self) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
@@ -166,36 +164,6 @@ class Document(DocumentBase):
)
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str
section_link: str | None
@classmethod
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
return cls(
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
)
@classmethod
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
return cls(
id=str(data.get("id")),
semantic_id=str(data.get("semantic_id")),
section_link=str(data.get("section_link")),
)
def to_dict(self) -> dict[str, str | None]:
return {
"id": self.id,
"semantic_id": self.semantic_id,
"section_link": self.section_link,
}
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
num_exceptions: int = 0
connector_id: int
credential_id: int

View File

@@ -68,13 +68,12 @@ def make_slack_api_call_paginated(
def make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 7
call: Callable[..., SlackResponse], max_retries: int = 3
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call
@@ -86,20 +85,14 @@ def make_slack_api_rate_limited(
return response
except SlackApiError as e:
last_exception = e
try:
error = e.response["error"]
except KeyError:
error = "unknown error"
if error == "ratelimited":
if e.response["error"] == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.info(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction"]:
elif e.response["error"] in ["already_reacted", "no_reaction"]:
# The response isn't used for reactions, this is basically just a pass
return e.response
else:
@@ -107,11 +100,7 @@ def make_slack_api_rate_limited(
raise
# If the code reaches this point, all retries have been exhausted
msg = f"Max retries ({max_retries}) exceeded"
if last_exception:
raise Exception(msg) from last_exception
else:
raise Exception(msg)
raise Exception(f"Max retries ({max_retries}) exceeded")
return rate_limited_call

View File

@@ -15,7 +15,6 @@ from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -84,28 +83,8 @@ def check_internet_connection(url: str) -> None:
try:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
status_code = e.response.status_code
error_msg = {
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Not Found",
500: "Internal Server Error",
502: "Bad Gateway",
503: "Service Unavailable",
504: "Gateway Timeout",
}.get(status_code, "HTTP Error")
raise Exception(f"{error_msg} ({status_code}) for {url} - {e}")
except requests.exceptions.SSLError as e:
cause = (
e.args[0].reason
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
else e.args
)
raise Exception(f"SSL error {str(cause)}")
except (requests.RequestException, ValueError) as e:
raise Exception(f"Unable to reach {url} - check your internet connection: {e}")
except (requests.RequestException, ValueError):
raise Exception(f"Unable to reach {url} - check your internet connection")
def is_valid_url(url: str) -> bool:

View File

@@ -15,6 +15,7 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -23,6 +24,7 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories=categories,
pages=pages,
recurse_depth=recurse_depth,
connector_name=connector_name,
language_code=language_code,
batch_size=batch_size,
)

View File

@@ -1,7 +1,5 @@
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
@@ -21,24 +19,12 @@ from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
def _article_to_document(article: Article) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
labels = [str(label) for label in article.label_names]
return Document(
id=f"article:{article.id}",
@@ -49,7 +35,7 @@ def _article_to_document(article: Article, content_tags: dict[str, str]) -> Docu
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author],
metadata=metadata,
metadata={"labels": labels} if labels else {},
)
@@ -62,42 +48,6 @@ class ZendeskConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -112,11 +62,6 @@ class ZendeskConnector(LoadConnector, PollConnector):
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -147,30 +92,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
):
continue
doc_batch.append(_article_to_document(article, self.content_tags))
doc_batch.append(_article_to_document(article))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
if doc_batch:
yield doc_batch
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()
connector.load_credentials(
{
"zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
"zendesk_email": os.environ["ZENDESK_EMAIL"],
"zendesk_token": os.environ["ZENDESK_TOKEN"],
}
)
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
print(next(document_batches))

View File

@@ -3,7 +3,6 @@ from typing import List
from typing import Optional
from pydantic import BaseModel
from pydantic import Field
class Message(BaseModel):
@@ -19,11 +18,11 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int]
edit_history: Any = None
last_edit_timestamp: Optional[int] = None
edit_history: Any
reactions: List[Any]
submessages: List[Any]
flags: List[str] = Field(default_factory=list)
flags: List[str] = []
display_recipient: Optional[str] = None
type: Optional[str] = None
stream_id: int
@@ -40,4 +39,4 @@ class GetMessagesResponse(BaseModel):
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
messages: List[Message] = Field(default_factory=list)
messages: List[Message] = []

Some files were not shown because too many files have changed in this diff Show More