Compare commits

3 Commits

Author         SHA1        Message                     Date
pablodanswer   c29beaf403  combobox                    2024-11-14 16:26:41 -08:00
pablodanswer   46f84d15f8  content scroll differences  2024-11-14 16:26:41 -08:00
pablodanswer   e8c93199f2  minor dropdown fix          2024-11-14 16:26:41 -08:00

443 changed files with 6966 additions and 17208 deletions

@@ -65,7 +65,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

@@ -13,10 +13,7 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -198,9 +195,6 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \

@@ -1,225 +0,0 @@
name: Run Chromatic Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

@@ -20,12 +20,9 @@ env:
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

.gitignore
@@ -7,4 +7,3 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

@@ -12,7 +12,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -135,7 +135,7 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ✨Contributors
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>

@@ -73,7 +73,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

@@ -1,59 +0,0 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

@@ -1,45 +0,0 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

@@ -1,280 +0,0 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

@@ -1,45 +0,0 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

@@ -1,35 +0,0 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

@@ -7,7 +7,6 @@ Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
@@ -16,6 +15,7 @@ from danswer.db.models import (
Credential,
IndexAttempt,
)
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
@@ -30,11 +30,13 @@ def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
connectors_to_delete = (
session.query(Connector)
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
.all()
)
connector_ids = [row[0] for row in result]
connector_ids = [connector.id for connector in connectors_to_delete]
if connector_ids:
cc_pairs_to_delete = (

@@ -1,7 +1,7 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Revises: c0fd6e4da83a
Create Date: 2024-11-12 15:16:42.682902
"""

@@ -1,27 +0,0 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

@@ -1,30 +0,0 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

@@ -1,42 +0,0 @@
"""extended_role_for_non_web
Revision ID: dfbe9e93d3c7
Revises: 9cf5c00f72fe
Create Date: 2024-11-16 07:54:18.727906
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7"
down_revision = "9cf5c00f72fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE "user"
SET role = 'EXT_PERM_USER'
WHERE has_web_login = false
"""
)
op.drop_column("user", "has_web_login")
def downgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
op.execute(
"""
UPDATE "user"
SET has_web_login = false,
role = 'BASIC'
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
"""
)

@@ -2,8 +2,8 @@ from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:

@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

@@ -13,9 +13,6 @@ class UserRole(str, Enum):
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
- Limited can access a limited set of basic api endpoints
- Slack are users that have used danswer via slack but dont have a web login
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
"""
LIMITED = "limited"
@@ -23,14 +20,6 @@ class UserRole(str, Enum):
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,
UserRole.EXT_PERM_USER,
]
class UserStatus(str, Enum):
@@ -45,8 +34,10 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

@@ -49,7 +49,8 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@@ -80,8 +81,8 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
@@ -221,8 +222,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
self,
user_create: schemas.UC | UserCreate,
@@ -248,9 +247,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
@@ -269,9 +266,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
@@ -285,7 +287,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def oauth_callback(
self,
self: "BaseUserManager[models.UOAP, models.ID]",
oauth_name: str,
access_token: str,
account_id: str,
@@ -296,7 +298,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> User:
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
@@ -322,11 +324,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
self.database = tenant_user_db # type: ignore
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -378,11 +378,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
user, existing_oauth_account, oauth_account_dict
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
@@ -395,15 +391,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login():
if not user.has_web_login: # type: ignore
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
"has_web_login": True,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -478,7 +475,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.role.is_web_login():
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -609,7 +608,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
async_db_session: AsyncSession,
db_session: Session,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -618,21 +617,13 @@ async def optional_user_(
async def optional_user(
request: Request,
async_db_session: AsyncSession = Depends(get_async_session),
db_session: Session = Depends(get_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
return await versioned_fetch_user(request, user, db_session)
async def double_check_user(
@@ -918,8 +909,8 @@ def get_oauth_router(
return router
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@@ -929,7 +920,7 @@ async def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

@@ -1,6 +1,5 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -15,16 +14,10 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -98,15 +91,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
@@ -156,23 +140,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:

@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -16,7 +17,6 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
@@ -111,15 +111,10 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

@@ -1,11 +1,12 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -18,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -36,7 +37,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -59,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, lock_beat, tenant_id
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -85,7 +86,8 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
r: Redis,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
@@ -116,7 +118,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletePayload(
fence_payload = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)

@@ -8,7 +8,6 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -25,10 +24,10 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.db.users import batch_add_non_web_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -139,7 +138,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -163,7 +162,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
result = app.send_task(
app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -175,8 +174,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
payload = RedisConnectorPermissionSyncData(
started=None,
)
redis_connector.permissions.set_fence(payload)
@@ -242,17 +241,13 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
raise ValueError(f"No doc sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
logger.info(f"Syncing docs for {source_type}")
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
@@ -306,7 +301,7 @@ def update_external_document_permissions_task(
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
batch_add_ext_perm_user_if_not_exists(
batch_add_non_web_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)

@@ -8,7 +8,6 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -25,15 +24,12 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
@@ -53,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip external group sync if not active
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -70,9 +66,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
if not source_sync_period:
return True
@@ -111,7 +107,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -129,7 +125,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_external_group_sync_task(
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -160,7 +156,7 @@ def try_creating_external_group_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
_ = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -170,13 +166,8 @@ def try_creating_external_group_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
except Exception:
task_logger.exception(
@@ -204,7 +195,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -212,7 +203,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = r.lock(
lock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -237,13 +228,9 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
raise ValueError(f"No external group sync func found for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -262,6 +249,7 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -272,6 +260,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(None)
redis_connector.external_group_sync.set_fence(False)
if lock.owned():
lock.release()

@@ -10,13 +10,12 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -25,34 +24,27 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -64,97 +56,30 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_lock: redis.lock.Lock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: RedisLock = redis_lock
self.redis_lock: redis.lock.Lock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -162,10 +87,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -175,9 +100,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -196,24 +118,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -229,80 +153,33 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
reindex,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"Indexing queued: index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -310,14 +187,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if lock_beat.owned():
lock_beat.release()
return tasks_created
@@ -326,7 +197,6 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -391,11 +261,6 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -428,11 +293,10 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -461,7 +325,7 @@ def try_creating_indexing_task(
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexPayload(
payload = RedisConnectorIndexingFenceData(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -483,8 +347,6 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -504,17 +366,15 @@ def try_creating_indexing_task(
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(payload)
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -523,11 +383,8 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -535,19 +392,15 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -558,7 +411,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -566,36 +419,14 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
break
sleep(10)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
@@ -612,7 +443,7 @@ def connector_indexing_proxy_task(
if job.status == "error":
task_logger.error(
f"Indexing watchdog - spawned task exceptioned: "
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -624,7 +455,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -632,38 +463,6 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -700,8 +499,7 @@ def connector_indexing_task(
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -718,7 +516,6 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -726,18 +523,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
if not redis_connector_index.fenced: # The fence must exist
# wait for the fence to come up
if not redis_connector_index.fenced:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload # The payload must exist
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -760,7 +557,7 @@ def connector_indexing_task(
)
break
lock: RedisLock = r.lock(
lock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -769,7 +566,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -804,7 +601,7 @@ def connector_indexing_task(
)
# define a callback class
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

View File

@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -39,14 +39,7 @@ logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
Next pruning time is calculated as a delta from the last successful prune, or the
last successful indexing if pruning has never succeeded.
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
if pruning fails (which is what it does now). A backoff could be reasonable.
"""
"""Returns boolean indicating if pruning is due."""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
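
The longer docstring above describes the due-time calculation: the next pruning time is a delta from the last successful prune, falling back to the last successful indexing if pruning has never succeeded. A minimal sketch of that check for illustration only (not code from this changeset); the `prune_freq`, `last_pruned`, and `last_successful_index_time` attributes, and `prune_freq` being in seconds, are assumptions:

from datetime import datetime, timedelta, timezone

def _is_pruning_due_sketch(cc_pair) -> bool:
    # Sketch only; attribute names are assumptions, not the actual ORM models.
    if cc_pair.connector.prune_freq is None:
        # no prune frequency set -> never due (pruning can still be forced via the API)
        return False
    # delta from the last successful prune, or the last successful index if
    # pruning has never succeeded
    last_run = cc_pair.last_pruned or cc_pair.last_successful_index_time
    if last_run is None:
        return False
    next_prune = last_run + timedelta(seconds=cc_pair.connector.prune_freq)
    return datetime.now(timezone.utc) >= next_prune
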
@@ -232,8 +225,6 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
@@ -264,11 +255,6 @@ def connector_pruning_generator_task(
)
return
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -277,13 +263,12 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
@@ -305,8 +290,8 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
task_logger.info(
@@ -329,10 +314,10 @@ def connector_pruning_generator_task(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.reset()
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
raise e
finally:
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")

View File

@@ -177,17 +177,7 @@ def document_by_cc_pair_cleanup_task(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
with get_session_with_tenant(tenant_id) as db_session:
mark_document_as_modified(document_id, db_session)
return False

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -12,7 +13,6 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
@@ -48,9 +48,11 @@ from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -59,7 +61,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -165,7 +167,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -183,12 +185,7 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -196,21 +193,22 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
if tasks_generated is None:
continue
if result[1] == 0:
if tasks_generated == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
total_tasks_generated += result[0]
total_tasks_generated += tasks_generated
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -225,7 +223,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -253,11 +251,12 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -266,7 +265,7 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set={document_set.id} tasks_generated={tasks_generated}"
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -279,7 +278,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
lock_beat: redis.lock.Lock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -308,11 +307,12 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -321,7 +321,7 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -441,22 +441,11 @@ def monitor_connector_deletion_taskset(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# in the logic gating off pruning and indexing work before deletion starts
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
)
# clean up the rest of the related Postgres entities
@@ -520,7 +509,8 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
def monitor_ccpair_pruning_taskset(
@@ -589,7 +579,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
payload: RedisConnectorPermissionSyncData | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -597,7 +587,9 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
def monitor_ccpair_indexing_taskset(
@@ -634,8 +626,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -644,58 +636,39 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
if status_int is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
redis_connector_index.reset()
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -716,7 +689,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -728,7 +701,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
@@ -754,6 +727,34 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_sync={n_permissions_sync} "
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

View File

@@ -1,8 +1,6 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = celery_app
app = celery_app

View File

@@ -1,10 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
app = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -1,5 +1,7 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -19,7 +21,6 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -30,7 +31,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -41,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
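
For illustration only (not code from this changeset), a minimal no-op implementation of this interface might look like the sketch below; a real callback would typically check a Redis stop fence in should_stop and publish progress updates in progress:

class NoOpIndexingCallback(RunIndexingCallbackInterface):
    """Never requests a stop and discards progress updates."""

    def should_stop(self) -> bool:
        return False

    def progress(self, amount: int) -> None:
        # intentionally a no-op; a real implementation might increment a Redis counter
        pass
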
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -88,15 +102,11 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -128,7 +138,13 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
@@ -141,7 +157,6 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -213,7 +228,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -274,7 +289,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -307,16 +322,26 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
f"Connector run ran into an exception after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
reason=str(e),
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -328,37 +353,6 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
@@ -425,7 +419,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:
@@ -439,13 +433,11 @@ def run_indexing_entrypoint(
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -453,8 +445,10 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -7,10 +7,10 @@ from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.context.search.models import InferenceSection
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
@@ -15,6 +14,7 @@ from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
@@ -81,7 +81,6 @@ def load_personas_from_yaml(
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)

View File

@@ -6,10 +6,10 @@ from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SearchResponse
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Be careful setting specific IDs; this won't autoincrement the next ID value for postgres
- id: 0
name: "Search"
name: "Knowledge"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml

View File

@@ -23,16 +23,6 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.context.search.utils import drop_llm_indices
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -66,6 +56,16 @@ from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@@ -112,7 +112,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
@@ -317,11 +316,6 @@ def stream_chat_message_objects(
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
@@ -347,7 +341,6 @@ def stream_chat_message_objects(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
long_term_logger=long_term_logger,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -605,7 +598,6 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)

View File

@@ -0,0 +1,115 @@
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If the query is multi-line, it likely confuses the model more than it helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()
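
A hypothetical usage sketch of the helpers above, assuming it runs in the same module so the functions and TOOL templates are in scope; the tool metadata and query are invented for the example:

example_tools: list[ToolInfo] = [
    {
        "name": "Current Search",
        "description": "Searches connected sources for recent documents.",
    }
]

# build the tool overview section, then wrap the user query and hint around it
tool_section = form_tool_section_text(example_tools, retrieval_enabled=True)
user_prompt = form_user_prompt_text(
    query="What changed in the latest release?",
    tool_text=tool_section,
    hint_text="Prefer recent documents",
)
print(user_prompt)
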

View File

@@ -234,7 +234,7 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
@@ -422,9 +422,6 @@ LOG_ALL_MODEL_INTERACTIONS = (
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -493,6 +490,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs

View File

@@ -1,9 +1,9 @@
import os
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -17,6 +17,9 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
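# e.g. 512 * 3 / 3072 = 1536 / 3072 = 0.5, i.e. roughly half of the ~3k-token input budget goes to docs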
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
@@ -24,6 +27,8 @@ DOC_TIME_DECAY = float(
)
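# e.g. with DOC_TIME_DECAY = 0.5, a 2-year-old document scores 1 / (1 + 0.5 * 2) = 0.5 under the formula above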
BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat

View File

@@ -60,6 +60,7 @@ KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
@@ -73,7 +74,7 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 3 hours
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.

View File

@@ -70,9 +70,7 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loading the documents from some sort of a dump file.
- Poll Connector:
- Poll connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter-weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
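
As a rough illustration, a minimal hypothetical connector supporting both the load and poll flows might look like the sketch below. This is not code from this changeset: the class name, constructor arguments, credential key, and the fetch helper are invented, while the base classes, type aliases, and method names follow the interfaces used elsewhere in this repository.

from typing import Any

from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch


class ExampleWikiConnector(LoadConnector, PollConnector):
    def __init__(self, team: str, batch_size: int = 100) -> None:
        self.team = team
        self.batch_size = batch_size
        self.api_token: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        # credential key name is hypothetical
        self.api_token = credentials["example_wiki_api_token"]
        return None

    def _fetch_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        # placeholder: call the (hypothetical) wiki API, convert results into
        # Document objects, and yield them in batches of self.batch_size
        yield from ()

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_documents()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        return self._fetch_documents(start=start, end=end)
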
The `__init__` should take arguments for configuring what documents the connector will index and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -5,9 +5,9 @@ from io import BytesIO
from typing import Any
from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from mypy_boto3_s3 import S3Client # type: ignore
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType

View File

@@ -7,9 +7,9 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -51,8 +51,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -72,7 +70,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self._confluence_client: OnyxConfluence | None = None
self.confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -83,15 +81,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
elif space:
# if no cql_query is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
@@ -99,44 +97,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
self.confluence_client = build_confluence_client(
credentials_json=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comment in self.confluence_client.paginated_cql_retrieval(
for comments in self.confluence_client.paginated_cql_page_retrieval(
cql=comment_cql,
expand=expand,
):
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
return comment_string
@@ -148,6 +141,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If it's a page, it extracts the text and adds the comments to the document text.
If it's an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
@@ -157,19 +153,16 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
self.confluence_client, confluence_object
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client, attachment=confluence_object
self.confluence_client, confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -200,39 +193,44 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
for attachments in self.confluence_client.paginated_cql_page_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -257,52 +255,52 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
for pages in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []

View File

@@ -20,10 +20,6 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -84,7 +80,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -99,10 +95,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
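The retry loop above sleeps in one-second increments rather than one long sleep so that, as the comment notes, a shutdown signal could eventually interrupt the wait before the deadline computed by _handle_http_error. That wait pattern on its own (the should_exit hook is a hypothetical placeholder for such a signal):

import time
from collections.abc import Callable

def wait_until(deadline: float, should_exit: Callable[[], bool] = lambda: False) -> None:
    """Sleep in short increments so a caller-supplied signal can cut the wait short."""
    while time.monotonic() < deadline:
        if should_exit():
            break
        time.sleep(1)

# wait_until(time.monotonic() + 3)  # blocks for ~3 seconds unless should_exit() returns True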
@@ -120,7 +112,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 1000
_DEFAULT_PAGINATION_LIMIT = 100
class OnyxConfluence(Confluence):
@@ -149,7 +141,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[dict[str, Any]]:
) -> Iterator[list[dict[str, Any]]]:
"""
This will paginate through the top level query.
"""
@@ -161,43 +153,46 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
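Both the old and new versions of _paginate_url follow Confluence's cursor-style pagination: each response carries its page of results plus a _links.next suffix for the following request, and iteration stops once that link is absent. A standalone sketch of the same loop using requests directly (the base URL, session auth, and exact response shape are assumptions based on the code above):

import requests

def paginate(base_url: str, url_suffix: str, session: requests.Session):
    """Yield each page's 'results' list, following the '_links.next' cursor."""
    while url_suffix:
        response = session.get(f"{base_url}/{url_suffix.lstrip('/')}")
        response.raise_for_status()
        data = response.json()
        yield data.get("results", [])
        url_suffix = data.get("_links", {}).get("next")

# for results in paginate("https://example.atlassian.net/wiki", "rest/api/group?limit=100", requests.Session()):
#     ...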
def paginated_cql_retrieval(
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
yield from self._paginate_url(
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -206,7 +201,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
) -> Iterator[list[dict[str, Any]]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -226,113 +221,6 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a separate endpoint from the content/search endpoint, used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL-like query.
It's a Confluence-specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL-like query.
It's a Confluence-specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL-like query.
It's a Confluence-specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for Confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
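The _traverse_and_update helper referenced above walks each returned object recursively so every nested expansion can be paginated in turn; only its entry point is visible in this hunk. As a rough illustration of that style of traversal (the expansion-fetching itself is omitted here), a generic walk over nested dicts and lists looks like:

from typing import Any
from collections.abc import Callable

def traverse(data: Any, visit: Callable[[dict], None]) -> None:
    """Apply `visit` to every dict found anywhere in a nested dict/list structure."""
    if isinstance(data, dict):
        visit(data)
        for value in data.values():
            traverse(value, visit)
    elif isinstance(data, list):
        for item in data:
            traverse(item, visit)

# traverse({"results": [{"id": 1, "children": [{"id": 2}]}]}, print)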

View File

@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -72,9 +71,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -82,7 +79,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
@@ -104,72 +101,38 @@ def extract_text_from_confluence_html(
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page-within-page ("include") macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
for html_page_reference in soup.findAll("ri:page"):
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_title = html_page_reference.attrs["ri:content-title"]
if not page_title:
continue
page_query = f"type=page and title='{page_title}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
for page_batch in confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
page_contents = page_batch[0]
break
except Exception as e:
except Exception:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
f"Error getting page contents for object {confluence_object}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
confluence_client, page_contents
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
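The loops above lean on BeautifulSoup's replaceWith to swap a Confluence macro or link element for plain text in place, so the substituted text survives the final plain-text extraction. A self-contained sketch of that substitution on a toy fragment (the parser choice here is incidental):

from bs4 import BeautifulSoup

html = "<p>See <ac:link-body>the setup guide</ac:link-body> for details.</p>"
soup = BeautifulSoup(html, "html.parser")

for link_body in soup.findAll("ac:link-body"):
    # Replace the element with its text so it survives plain-text extraction
    link_body.replaceWith(f"(LINK TEXT: {link_body.text})")

print(soup.get_text())  # -> See (LINK TEXT: the setup guide) for details.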
@@ -269,3 +232,20 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)

View File

@@ -1,8 +1,8 @@
import os
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import Issue
@@ -12,93 +12,129 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
from danswer.connectors.danswer_jira.utils import build_jira_client
from danswer.connectors.danswer_jira.utils import build_jira_url
from danswer.connectors.danswer_jira.utils import extract_jira_project
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
from danswer.connectors.danswer_jira.utils import get_comment_strs
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def _paginate_jql_search(
jira_client: JIRA,
jql: str,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
if len(issues) < max_results:
break
return jira_base, jira_project
start += max_results
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will, for example, skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in jira.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
def fetch_jira_issues_batch(
jira_client: JIRA,
jql: str,
batch_size: int,
start_index: int,
jira_client: JIRA,
batch_size: int = INDEX_BATCH_SIZE,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
) -> tuple[list[Document], int]:
doc_batch = []
batch = jira_client.search_issues(
jql,
startAt=start_index,
maxResults=batch_size,
)
for jira in batch:
if type(jira) != Issue:
logger.warning(f"Found Jira object not of type Issue {jira}")
continue
if labels_to_skip and any(
label in jira.fields.labels for label in labels_to_skip
):
logger.info(
f"Skipping {jira.key} because it has a label to skip. Found "
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
issue.fields.description
jira.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
else extract_text_from_adf(jira.raw["fields"]["description"])
)
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
@@ -106,53 +142,66 @@ def fetch_jira_issues_batch(
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"Skipping {jira.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
people.add(
BasicExpertInfo(
display_name=jira.fields.creator.displayName,
email=jira.fields.creator.emailAddress,
)
)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
people.add(
BasicExpertInfo(
display_name=jira.fields.assignee.displayName, # type: ignore
email=jira.fields.assignee.emailAddress, # type: ignore
)
)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
yield Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=issue.fields.summary,
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
)
return doc_batch, len(batch)
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnector(LoadConnector, PollConnector):
def __init__(
self,
jira_project_url: str,
@@ -164,8 +213,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
self.jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@@ -174,45 +223,54 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@property
def jira_client(self) -> JIRA:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
return self._jira_client
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
)
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
self.jira_client = JIRA(
basic_auth=(email, api_token),
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
self.jira_client = JIRA(
token_auth=api_token,
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
jql = f"project = {self.quoted_jira_project}"
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
yield document_batch
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
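Both variants of the Jira fetch above page through search results with startAt/maxResults, stopping once a page comes back smaller than the page size. The offset loop in isolation, with a hypothetical search callable standing in for jira_client.search_issues:

from collections.abc import Callable, Iterator, Sequence
from typing import Any

def paginate_offsets(
    search: Callable[[int, int], Sequence[Any]], page_size: int
) -> Iterator[Any]:
    """Yield results page by page until a short (or empty) page signals the end."""
    start = 0
    while True:
        page = search(start, page_size)
        yield from page
        if len(page) < page_size:
            break
        start += page_size

# list(paginate_offsets(lambda start, size: list(range(start, min(start + size, 7))), 3))
# -> [0, 1, 2, 3, 4, 5, 6]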
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -220,54 +278,31 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {self.quoted_jira_project} AND "
f"project = {quoted_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
jira_client=self.jira_client,
jql=jql,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
issue_key = best_effort_get_field_from_issue(issue, "key")
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
perm_sync_data=None,
)
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=jql,
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
yield slim_doc_batch
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
if __name__ == "__main__":

View File

@@ -1,136 +1,17 @@
"""Module with custom fields processing functions"""
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
display_name = None
email = None
if hasattr(obj, "display_name"):
display_name = obj.display_name
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
if not email and not display_name:
return None
return BasicExpertInfo(display_name=display_name, email=email)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will, for example, skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
return JIRA(
basic_auth=(email, api_token),
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
return JIRA(
token_auth=api_token,
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
return jira_base, jira_project
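As a quick check of the URL parsing above, a made-up project URL such as https://example.atlassian.net/jira/software/projects/ENG/boards/1 splits into the base https://example.atlassian.net and the project key ENG:

from urllib.parse import urlparse

url = "https://example.atlassian.net/jira/software/projects/ENG/boards/1"
parsed = urlparse(url)
jira_base = f"{parsed.scheme}://{parsed.netloc}"

# 'projects' sits right before the project key in the path
split_path = parsed.path.split("/")
jira_project = split_path[split_path.index("projects") + 1]

print(jira_base, jira_project)  # -> https://example.atlassian.net ENG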
def get_comment_strs(
issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in issue.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
class CustomFieldExtractor:
@staticmethod
def _process_custom_field_value(value: Any) -> str:

View File

@@ -15,7 +15,6 @@ from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
@@ -83,31 +82,12 @@ def _process_files_batch(
yield doc_batch
def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
all_drive_ids_available: set[str],
) -> tuple[set[str], set[str]]:
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
if invalid_requested_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_requested_drive_ids)
valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
return valid_requested_drive_ids, filtered_folder_ids
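_clean_requested_drive_ids is plain set arithmetic: any requested drive ID that is not visible as a shared drive gets retried as a folder, and only the verified IDs stay in the drive list. A worked example with made-up IDs:

requested_drives = {"drive-a", "drive-b", "maybe-folder"}
requested_folders = {"folder-x"}
available_drives = {"drive-a", "drive-b", "drive-c"}

invalid = requested_drives - available_drives              # {'maybe-folder'}
folders = (requested_folders - available_drives) | invalid  # retry unknown drives as folders
valid_drives = requested_drives - invalid                   # {'drive-a', 'drive-b'}

print(sorted(valid_drives), sorted(folders))
# -> ['drive-a', 'drive-b'] ['folder-x', 'maybe-folder']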
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = False,
include_my_drives: bool = False,
include_files_shared_with_me: bool = False,
include_shared_drives: bool = True,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
@@ -140,36 +120,22 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
):
raise ValueError(
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
)
self.batch_size = batch_size
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
self.include_my_drives = False if specific_requests_made else include_my_drives
self.include_shared_drives = (
False if specific_requests_made else include_shared_drives
)
self.include_shared_drives = include_shared_drives
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
@@ -259,20 +225,26 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
creds=self.creds,
user_email=self.primary_admin_email,
)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can still
# access your My Drive as a non-admin user in an org
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=is_service_account,
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
if not all_drive_ids:
logger.warning(
"No drives found even though we are indexing shared drives was requested."
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
)
all_drive_ids = set(self._requested_shared_drive_ids)
return all_drive_ids
@@ -289,9 +261,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - no specific emails were requested
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -322,7 +299,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
def _manage_service_account_retrieval(
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
@@ -332,16 +309,29 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
all_drive_ids: set[str] = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)
# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
)
else:
filtered_drive_ids = all_drive_ids
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
@@ -350,8 +340,8 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
self._impersonate_user_for_retrieval,
email,
is_slim,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
filtered_drive_ids,
filtered_folder_ids,
start,
end,
): email
@@ -363,101 +353,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
yield from future.result()
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
filtered_drive_ids | filtered_folder_ids
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _manage_oauth_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
start=start,
end=end,
)
all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
all_drive_ids = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
retrieval_method = (
self._manage_service_account_retrieval
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,

View File

@@ -2,7 +2,6 @@ import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -49,67 +48,6 @@ def _extract_sections_basic(
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -127,7 +65,6 @@ def _extract_sections_basic(
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,

View File

@@ -140,8 +140,8 @@ def get_files_in_shared_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -152,7 +152,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=folder_query,
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -160,9 +160,9 @@ def get_files_in_shared_drive(
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
@@ -172,7 +172,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
q=query,
)
@@ -185,16 +185,14 @@ def get_all_files_in_my_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
query = "trashed = false and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=folder_query,
q=query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -202,52 +200,18 @@ def get_all_files_in_my_drive(
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
)
def get_all_files_for_oauth(
service: Any,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
corpora = "allDrives" if should_get_all else "user"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
if not should_get_all:
if include_files_shared_with_me and not include_my_drives:
file_query += " and not 'me' in owners"
if not include_files_shared_with_me and include_my_drives:
file_query += " and 'me' in owners"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
q=query,
)
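Every retrieval path in this file composes a Drive q filter from the same pieces: a mimeType clause to include or exclude folders, trashed = false, an ownership clause, and an optional modifiedTime window appended by _generate_time_range_filter. A rough sketch of how such a filter might be assembled (the timestamp formatting is an assumption, since _generate_time_range_filter's body is not shown in this diff):

from datetime import datetime, timezone

DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"

def build_file_query(start: float | None = None, end: float | None = None) -> str:
    """Compose a Drive 'q' filter: non-folders, not trashed, owned by me,
    optionally bounded by modification time."""
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and trashed = false and 'me' in owners"
    if start is not None:
        start_str = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
        query += f" and modifiedTime >= '{start_str}'"
    if end is not None:
        end_str = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
        query += f" and modifiedTime <= '{end_str}'"
    return query

# build_file_query(start=0) appends "... and modifiedTime >= '1970-01-01T00:00:00+00:00'"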

View File

@@ -12,15 +12,12 @@ from dateutil import parser
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
@@ -31,8 +28,6 @@ logger = setup_logger()
SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"
_SLIM_BATCH_SIZE = 1000
def run_graphql_request(
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
class SlabConnector(LoadConnector, PollConnector):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
slab_bot_token: str | None = None,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
self._slab_bot_token: str | None = None
self.slab_bot_token = slab_bot_token
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._slab_bot_token = credentials["slab_bot_token"]
self.slab_bot_token = credentials["slab_bot_token"]
return None
@property
def slab_bot_token(self) -> str:
if self._slab_bot_token is None:
raise ConnectorMissingCredentialError("Slab")
return self._slab_bot_token
def _iterate_posts(
self, time_filter: Callable[[datetime], bool] | None = None
) -> GenerateDocumentsOutput:
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
yield from self._iterate_posts(
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):
slim_doc_batch.append(
SlimDocument(
id=post_id,
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch

View File

@@ -102,21 +102,13 @@ def _get_tickets(
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
return None
try:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
except requests.exceptions.HTTPError:
# Handle any API errors gracefully
return None
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
def _article_to_document(

View File

@@ -18,30 +18,20 @@ from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import SavedSearchDoc
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
@@ -111,12 +101,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
return chunks
def _clean_markdown_link_text(text: str) -> str:
def clean_markdown_link_text(text: str) -> str:
# Remove any newlines within the text
return text.replace("\n", " ").strip()
def _build_qa_feedback_block(
def build_qa_feedback_block(
message_id: int, feedback_reminder_id: str | None = None
) -> Block:
return ActionsBlock(
@@ -125,6 +115,7 @@ def _build_qa_feedback_block(
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
text="👍 Helpful",
style="primary",
value=feedback_reminder_id,
),
ButtonElement(
@@ -164,7 +155,7 @@ def get_document_feedback_blocks() -> Block:
)
def _build_doc_feedback_block(
def build_doc_feedback_block(
message_id: int,
document_id: str,
document_rank: int,
@@ -191,7 +182,7 @@ def get_restate_blocks(
]
def _build_documents_blocks(
def build_documents_blocks(
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
@@ -232,7 +223,7 @@ def _build_documents_blocks(
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = _build_doc_feedback_block(
feedback = build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
@@ -250,7 +241,7 @@ def _build_documents_blocks(
return section_blocks
def _build_sources_blocks(
def build_sources_blocks(
cited_documents: list[tuple[int, SavedSearchDoc]],
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
@@ -295,7 +286,7 @@ def _build_sources_blocks(
+ ([days_ago_str] if days_ago_str else [])
)
document_title = _clean_markdown_link_text(doc_sem_id)
document_title = clean_markdown_link_text(doc_sem_id)
img_link = source_to_github_img_link(d.source_type)
section_blocks.append(
@@ -326,50 +317,7 @@ def _build_sources_blocks(
return section_blocks
def _priority_ordered_documents_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
if not priority_ordered_docs:
return []
document_blocks = _build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if document_blocks:
document_blocks = [DividerBlock()] + document_blocks
return document_blocks
def _build_citations_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = _build_sources_blocks(cited_documents=cited_docs)
return citations_block
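The citation block builder above pairs each citation with its source document by document_id and then sorts by citation number so the list renders in order. That matching step on its own, with made-up data:

top_docs = [
    {"document_id": "doc-a", "title": "Runbook"},
    {"document_id": "doc-b", "title": "Postmortem"},
]
citations = [
    {"citation_num": 2, "document_id": "doc-b"},
    {"citation_num": 1, "document_id": "doc-a"},
]

cited_docs = []
for citation in citations:
    # Find the retrieved document this citation points at, if any
    match = next((d for d in top_docs if d["document_id"] == citation["document_id"]), None)
    if match:
        cited_docs.append((citation["citation_num"], match["title"]))

cited_docs.sort()
print(cited_docs)  # -> [(1, 'Runbook'), (2, 'Postmortem')]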
def _build_quotes_block(
def build_quotes_block(
quotes: list[DanswerQuote],
) -> list[Block]:
quote_lines: list[str] = []
@@ -411,70 +359,58 @@ def _build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def _build_qa_response_blocks(
answer: OneShotQAResponse,
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
skip_ai_feedback: bool = False,
feedback_reminder_id: str | None = None,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen; even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
quotes = answer.quotes.quotes if answer.quotes else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
or retrieval_info.applied_source_filters
):
if time_cutoff or favor_recent or source_filters:
filter_text = "Filters: "
if retrieval_info.applied_source_filters:
sources_str = ", ".join(
[s.value for s in retrieval_info.applied_source_filters]
)
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
filter_text += f"`Sources in [{sources_str}]`"
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
):
if time_cutoff or favor_recent:
filter_text += " and "
if retrieval_info.applied_time_cutoff is not None:
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
if retrieval_info.recency_bias_multiplier > 1:
if retrieval_info.applied_time_cutoff is not None:
if favor_recent:
if time_cutoff is not None:
filter_text += "+ "
filter_text += "`Prioritize Recently Updated Docs`"
filter_block = SectionBlock(text=f"_{filter_text}_")
if not formatted_answer:
if not answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = _build_quotes_block(quotes)
quotes_blocks = build_quotes_block(quotes)
# if no quotes OR `_build_quotes_block()` did not give back any blocks
# if no quotes OR `build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
@@ -489,37 +425,20 @@ def _build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if message_id is not None and not skip_ai_feedback:
response_blocks.append(
build_qa_feedback_block(
message_id=message_id, feedback_reminder_id=feedback_reminder_id
)
)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
)
return ActionsBlock(
block_id=build_continue_in_web_ui_id(message_id),
elements=[
ButtonElement(
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
text="Continue Chat in Danswer!",
style="primary",
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
),
],
)
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
return ActionsBlock(
block_id=build_feedback_id(message_id) if message_id is not None else None,
elements=[
@@ -564,77 +483,3 @@ def build_follow_up_resolved_blocks(
]
)
return [text_block, button_block]
def build_slack_response_blocks(
tenant_id: str | None,
message_info: SlackMessageInfo,
answer: OneShotQAResponse,
persona: Persona | None,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
answer_blocks = _build_qa_response_blocks(
answer=answer,
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
)
web_follow_up_block = []
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
message_id=answer.chat_message_id,
feedback_reminder_id=feedback_reminder_id,
)
)
citations_blocks = []
document_blocks = []
if use_citations:
# if citations are enabled, only show cited documents
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
all_blocks = (
restate_question_block
+ answer_blocks
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block
)
return all_blocks
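
For orientation on the hunk above, the sketch below isolates the ordering rule applied by _priority_ordered_documents_blocks: documents the LLM actually consumed come first, the remaining hits keep their retrieval order. This is a minimal standalone illustration, not the Danswer implementation; the sample inputs are made up.

# Minimal sketch of the priority ordering: LLM-selected docs first, then the rest.
def priority_order(top_docs: list[str], llm_selected_indices: list[int]) -> list[str]:
    llm_docs = [top_docs[i] for i in llm_selected_indices]
    remaining = [doc for idx, doc in enumerate(top_docs) if idx not in llm_selected_indices]
    return llm_docs + remaining


print(priority_order(["doc_a", "doc_b", "doc_c", "doc_d"], [2, 0]))
# -> ['doc_c', 'doc_a', 'doc_b', 'doc_d']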

View File

@@ -2,8 +2,8 @@ import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackChannelConfig
from danswer.db.slack_channel_config import fetch_slack_channel_configs
from danswer.db.models import SlackBotConfig
from danswer.db.slack_bot_config import fetch_slack_bot_configs
VALID_SLACK_FILTERS = [
@@ -13,52 +13,46 @@ VALID_SLACK_FILTERS = [
]
def get_slack_channel_config_for_bot_and_channel(
db_session: Session,
slack_bot_id: int,
channel_name: str | None,
) -> SlackChannelConfig | None:
def get_slack_bot_config_for_channel(
channel_name: str | None, db_session: Session
) -> SlackBotConfig | None:
if not channel_name:
return None
slack_bot_configs = fetch_slack_channel_configs(
db_session=db_session, slack_bot_id=slack_bot_id
)
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
for config in slack_bot_configs:
if channel_name in config.channel_config["channel_name"]:
if channel_name in config.channel_config["channel_names"]:
return config
return None
def validate_channel_name(
def validate_channel_names(
channel_names: list[str],
current_slack_bot_config_id: int | None,
db_session: Session,
current_slack_bot_id: int,
channel_name: str,
current_slack_channel_config_id: int | None,
) -> str:
"""Make sure that this channel_name does not exist in other Slack channel configs.
Returns a cleaned up channel name (e.g. '#' removed if present)"""
slack_bot_configs = fetch_slack_channel_configs(
db_session=db_session,
slack_bot_id=current_slack_bot_id,
)
cleaned_channel_name = channel_name.lstrip("#").lower()
for slack_channel_config in slack_bot_configs:
if slack_channel_config.id == current_slack_channel_config_id:
) -> list[str]:
"""Make sure that these channel_names don't exist in other slack bot configs.
Returns a list of cleaned up channel names (e.g. '#' removed if present)"""
slack_bot_configs = fetch_slack_bot_configs(db_session=db_session)
cleaned_channel_names = [
channel_name.lstrip("#").lower() for channel_name in channel_names
]
for slack_bot_config in slack_bot_configs:
if slack_bot_config.id == current_slack_bot_config_id:
continue
if cleaned_channel_name == slack_channel_config.channel_config["channel_name"]:
raise ValueError(
f"Channel name '{channel_name}' already exists in "
"another Slack channel config with in Slack Bot with name: "
f"{slack_channel_config.slack_bot.name}"
)
for channel_name in cleaned_channel_names:
if channel_name in slack_bot_config.channel_config["channel_names"]:
raise ValueError(
f"Channel name '{channel_name}' already exists in "
"another slack bot config"
)
return cleaned_channel_name
return cleaned_channel_names
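
As a quick illustration of the validation shown above (both the single-name and list variants), the standalone sketch below mirrors the cleanup and duplicate check: strip a leading '#', lowercase, and reject names already claimed by another config. The example channel names and the already_configured set are hypothetical.

# Standalone sketch of the channel-name cleanup and duplicate check.
def clean_channel_names(channel_names: list[str]) -> list[str]:
    return [name.lstrip("#").lower() for name in channel_names]


already_configured = {"support", "eng-help"}
for name in clean_channel_names(["#Announcements", "Support"]):
    if name in already_configured:
        print(f"Channel name '{name}' already exists in another config")
    else:
        print(f"'{name}' is free to use")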
# Scaling configurations for multi-tenant Slack channel handling
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant

View File

@@ -2,7 +2,6 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"

View File

@@ -13,7 +13,7 @@ from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.danswerbot.slack.blocks import build_follow_up_resolved_blocks
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
@@ -117,10 +117,8 @@ def handle_generate_answer_button(
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
handle_regular_answer(
@@ -135,7 +133,7 @@ def handle_generate_answer_button(
is_bot_msg=False,
is_bot_dm=False,
),
slack_channel_config=slack_channel_config,
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
@@ -258,16 +256,14 @@ def handle_followup_button(
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
if slack_channel_config:
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
if slack_bot_config:
tag_names = slack_bot_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_slack_user_ids_from_emails(
tag_ids, remaining = fetch_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:

View File

@@ -13,14 +13,14 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackChannelConfig
from danswer.db.users import add_slack_user_if_not_exists
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -106,7 +106,7 @@ def remove_scheduled_feedback_reminder(
def handle_message(
message_info: SlackMessageInfo,
slack_channel_config: SlackChannelConfig | None,
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str | None,
@@ -140,7 +140,7 @@ def handle_message(
)
document_set_names: list[str] | None = None
persona = slack_channel_config.persona if slack_channel_config else None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
@@ -152,8 +152,8 @@ def handle_message(
respond_member_group_list = None
channel_conf = None
if slack_channel_config and slack_channel_config.channel_config:
channel_conf = slack_channel_config.channel_config
if slack_bot_config and slack_bot_config.channel_config:
channel_conf = slack_bot_config.channel_config
if not bypass_filters and "answer_filters" in channel_conf:
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
@@ -184,7 +184,7 @@ def handle_message(
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_slack_user_ids_from_emails(
send_to, missing_ids = fetch_user_ids_from_emails(
respond_member_group_list, client
)
@@ -213,13 +213,13 @@ def handle_message(
with get_session_with_tenant(tenant_id) as db_session:
if message_info.email:
add_slack_user_if_not_exists(db_session, message_info.email)
add_non_web_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,
slack_channel_config=slack_channel_config,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
@@ -231,7 +231,7 @@ def handle_message(
# if no standard answer applies, try a regular answer
issue_with_regular_answer = handle_regular_answer(
message_info=message_info,
slack_channel_config=slack_channel_config,
slack_bot_config=slack_bot_config,
receiver_ids=send_to,
client=client,
channel=channel,

View File

@@ -7,6 +7,7 @@ from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
@@ -20,11 +21,12 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -32,8 +34,8 @@ from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.db.users import get_user_by_email
@@ -46,6 +48,10 @@ from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import DanswerLoggingAdapter
@@ -75,7 +81,7 @@ def rate_limits(
def handle_regular_answer(
message_info: SlackMessageInfo,
slack_channel_config: SlackChannelConfig | None,
slack_bot_config: SlackBotConfig | None,
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
@@ -90,7 +96,7 @@ def handle_regular_answer(
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
channel_conf = slack_bot_config.channel_config if slack_bot_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
@@ -102,7 +108,7 @@ def handle_regular_answer(
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
persona = slack_channel_config.persona if slack_channel_config else None
persona = slack_bot_config.persona if slack_bot_config else None
prompt = None
if persona:
document_set_names = [
@@ -114,9 +120,9 @@ def handle_regular_answer(
bypass_acl = False
if (
slack_channel_config
and slack_channel_config.persona
and slack_channel_config.persona.document_sets
slack_bot_config
and slack_bot_config.persona
and slack_bot_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
@@ -125,8 +131,8 @@ def handle_regular_answer(
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_channel_config is None
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
if slack_bot_config is None
else slack_bot_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to and not is_bot_msg:
@@ -228,8 +234,8 @@ def handle_regular_answer(
# persona.llm_filter_extraction if persona is not None else True
# )
auto_detect_filters = (
slack_channel_config.enable_auto_filters
if slack_channel_config is not None
slack_bot_config.enable_auto_filters
if slack_bot_config is not None
else False
)
retrieval_details = RetrievalDetails(
@@ -405,16 +411,62 @@ def handle_regular_answer(
)
return True
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
persona=persona,
channel_conf=channel_conf,
use_citations=use_citations,
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,

View File

@@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.db.models import Prompt
from danswer.db.models import SlackChannelConfig
from danswer.db.models import SlackBotConfig
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
@@ -14,7 +14,7 @@ logger = setup_logger()
def handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: DanswerLoggingAdapter,
client: WebClient,
@@ -29,7 +29,7 @@ def handle_standard_answers(
return versioned_handle_standard_answers(
message_info=message_info,
receiver_ids=receiver_ids,
slack_channel_config=slack_channel_config,
slack_bot_config=slack_bot_config,
prompt=prompt,
logger=logger,
client=client,
@@ -40,7 +40,7 @@ def handle_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: DanswerLoggingAdapter,
client: WebClient,

View File

@@ -4,7 +4,6 @@ import signal
import sys
import threading
import time
from collections.abc import Callable
from threading import Event
from types import FrameType
from typing import Any
@@ -17,7 +16,6 @@ from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
@@ -27,8 +25,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.context.search.retrieval.search_runner import download_nltk_data
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
@@ -57,25 +54,26 @@ from danswer.danswerbot.slack.handlers.handle_message import (
)
from danswer.danswerbot.slack.handlers.handle_message import schedule_feedback_reminder
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import check_message_limit
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_slack_bot_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackBot
from danswer.db.search_settings import get_current_search_settings
from danswer.db.slack_bot import fetch_slack_bots
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -84,8 +82,6 @@ from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -117,10 +113,8 @@ class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@@ -175,52 +169,6 @@ class SlackbotHandler:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
del self.socket_clients[tenant_bot_pair]
del self.slack_bot_tokens[tenant_bot_pair]
return
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
tokens_changed = (
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
)
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
)
else:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
@@ -255,7 +203,6 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
@@ -265,20 +212,57 @@ class SlackbotHandler:
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
bots = fetch_slack_bots(db_session=db_session)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if self.socket_clients.get(tenant_id):
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
if self.socket_clients.get(tenant_id):
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
@@ -297,37 +281,26 @@ class SlackbotHandler:
)
def start_socket_client(
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(
f"Starting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
socket_client: TenantSocketModeClient = _get_socket_client(
slack_bot_tokens, tenant_id, slack_bot_id
)
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
process_slack_event = create_process_slack_event()
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(
f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
self.socket_clients[tenant_id] = socket_client
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
for tenant_id, client in self.socket_clients.items():
if client:
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
@@ -411,7 +384,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
return False
bot_tag_id = get_danswer_bot_slack_bot_id(client.web_client)
bot_tag_id = get_danswer_bot_app_id(client.web_client)
if event_type == "message":
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and bot_tag_id in msg
@@ -434,15 +407,13 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# If DanswerBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
not slack_bot_config
or not slack_bot_config.channel_config.get("respond_to_bots")
):
channel_specific_logger.info("Ignoring message from bot")
return False
@@ -647,16 +618,14 @@ def process_message(
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
try:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
db_session=db_session,
slack_bot_id=client.slack_bot_id,
channel_name=channel_name,
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_channel_config is None
slack_bot_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
@@ -667,10 +636,9 @@ def process_message(
return
follow_up = bool(
slack_channel_config
and slack_channel_config.channel_config
and slack_channel_config.channel_config.get("follow_up_tags")
is not None
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
@@ -678,7 +646,7 @@ def process_message(
failed = handle_message(
message_info=details,
slack_channel_config=slack_channel_config,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
@@ -730,32 +698,26 @@ def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None
return process_feedback(req, client)
def create_process_slack_event() -> (
Callable[[TenantSocketModeClient, SocketModeRequest], None]
):
def process_slack_event(
client: TenantSocketModeClient, req: SocketModeRequest
) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
try:
if req.type == "interactive":
if req.payload.get("type") == "block_actions":
return action_routing(req, client)
elif req.payload.get("type") == "view_submission":
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception:
logger.exception("Failed to process slack event")
return process_slack_event
try:
if req.type == "interactive":
if req.payload.get("type") == "block_actions":
return action_routing(req, client)
elif req.payload.get("type") == "view_submission":
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception as e:
logger.exception(f"Failed to process slack event. Error: {e}")
logger.error(f"Slack request payload: {req.payload}")
def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.danswer.dev/slack_bot_setup
@@ -764,7 +726,6 @@ def _get_socket_client(
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)

View File

@@ -0,0 +1,28 @@
import os
from typing import cast
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import SlackBotTokens
def fetch_tokens() -> SlackBotTokens:
# first check env variables
app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
if app_token and bot_token:
return SlackBotTokens(app_token=app_token, bot_token=bot_token)
dynamic_config_store = get_kv_store()
return SlackBotTokens(
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
)
def save_tokens(
tokens: SlackBotTokens,
) -> None:
dynamic_config_store = get_kv_store()
dynamic_config_store.store(
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
)
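
The tokens module above resolves Slack credentials from environment variables first and only then falls back to the key-value store. Below is a minimal standalone sketch of that resolution order, with a plain dict standing in for the KV store; the helper name and dict key are illustrative, the environment variable names come from the file above.

# Env-var-first resolution, falling back to a persisted store (here a dict).
import os


def resolve_tokens(kv_store: dict) -> dict:
    app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
    bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
    if app_token and bot_token:
        return {"app_token": app_token, "bot_token": bot_token}
    # Mirrors the KvKeyNotFoundError path: a missing key raises KeyError here.
    return kv_store["slack_bot_tokens"]


print(resolve_tokens({"slack_bot_tokens": {"app_token": "xapp-example", "bot_token": "xoxb-example"}}))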

View File

@@ -3,9 +3,9 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
@@ -30,6 +30,7 @@ from danswer.configs.danswerbot_configs import (
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
@@ -46,16 +47,16 @@ from danswer.utils.text_processing import replace_whitespaces_w_space
logger = setup_logger()
_DANSWER_BOT_SLACK_BOT_ID: str | None = None
_DANSWER_BOT_APP_ID: str | None = None
_DANSWER_BOT_MESSAGE_COUNT: int = 0
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
def get_danswer_bot_slack_bot_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_SLACK_BOT_ID
if _DANSWER_BOT_SLACK_BOT_ID is None:
_DANSWER_BOT_SLACK_BOT_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_SLACK_BOT_ID
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_APP_ID
if _DANSWER_BOT_APP_ID is None:
_DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_APP_ID
def check_message_limit() -> bool:
@@ -136,10 +137,15 @@ def update_emote_react(
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_danswer_bot_slack_bot_id(web_client=client)
bot_tag_id = get_danswer_bot_app_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
def get_web_client() -> WebClient:
slack_tokens = fetch_tokens()
return WebClient(token=slack_tokens.bot_token)
@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -216,13 +222,6 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
unique_prefix = str(uuid.uuid4())[:10]
return unique_prefix + ID_SEPARATOR + str(message_id)
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
@@ -320,7 +319,7 @@ def get_channel_name_from_id(
raise e
def fetch_slack_user_ids_from_emails(
def fetch_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -438,9 +437,9 @@ def read_slack_thread(
)
message_type = MessageType.USER
else:
self_slack_bot_id = get_danswer_bot_slack_bot_id(client)
self_app_id = get_danswer_bot_app_id(client)
if reply.get("user") == self_slack_bot_id:
if reply.get("user") == self_app_id:
# DanswerBot response
message_type = MessageType.ASSISTANT
user_sem_id = "Assistant"
@@ -529,7 +528,7 @@ class SlackRateLimiter:
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: str | None
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
) -> None:
respond_in_thread(
client=client,
@@ -583,9 +582,6 @@ def get_feedback_visibility() -> FeedbackVisibility:
class TenantSocketModeClient(SocketModeClient):
def __init__(
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
):
def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id
self.slack_bot_id = slack_bot_id
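
Both sides of the hunk above attach per-tenant context to slack_sdk's SocketModeClient by subclassing it. A hedged standalone sketch of that pattern follows, showing only the tenant_id attribute from the simpler variant; the class name and token strings are placeholders and nothing is connected.

# Carry per-tenant context on the Slack socket client by subclassing it.
from typing import Any

from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient


class TenantSocketClient(SocketModeClient):
    def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.tenant_id = tenant_id


client = TenantSocketClient(
    tenant_id="tenant-123",
    app_token="xapp-placeholder",
    web_client=WebClient(token="xoxb-placeholder"),
)
print(client.tenant_id)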

View File

@@ -2,7 +2,6 @@ import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
@@ -46,16 +45,14 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(

View File

@@ -4,7 +4,6 @@ from typing import Any
from typing import Dict
from fastapi import Depends
from fastapi_users.models import ID
from fastapi_users.models import UP
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase
@@ -44,10 +43,7 @@ def get_total_users_count(db_session: Session) -> int:
"""
user_count = (
db_session.query(User)
.filter(
~User.email.endswith(get_api_key_email_pattern()), # type: ignore
User.role != UserRole.EXT_PERM_USER,
)
.filter(~User.email.endswith(get_api_key_email_pattern())) # type: ignore
.count()
)
invited_users = len(get_invited_users())
@@ -65,7 +61,7 @@ async def get_user_count() -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase):
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -19,9 +18,6 @@ from danswer.auth.schemas import UserRole
from danswer.chat.models import DocumentRelevance
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SavedSearchDoc
from danswer.context.search.models import SearchDoc as ServerSearchDoc
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
@@ -31,11 +27,13 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
@@ -252,50 +250,6 @@ def create_chat_session(
return chat_session
def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
"""
This takes a chat session id for a session in Slack and:
- Creates a new chat session in the DB
- Tries to copy the persona from the original chat session
(if it is available to the user clicking the button)
- Sets the user to the given user (if provided)
"""
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)
return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# Set this to empty string so the frontend will force a rename
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)
def update_chat_session(
db_session: Session,
user_id: UUID | None,
@@ -382,28 +336,6 @@ def get_chat_message(
return chat_message
def get_chat_session_by_message_id(
db_session: Session,
message_id: int,
) -> ChatSession:
"""
Should only be used for Slack
Get the chat session associated with a specific message ID
Note: this ignores permission checks.
"""
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if chat_message is None:
raise ValueError(
f"Unable to find chat session associated with message ID: {message_id}"
)
return chat_message.chat_session
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
user_id: UUID | None,
@@ -423,44 +355,6 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = get_or_create_root_message(
chat_session_id=new_chat_session_id,
db_session=db_session,
)
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
if chat_message.message_type == MessageType.SYSTEM:
continue
# Duplicate the message
new_root_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=chat_message.message,
files=chat_message.files,
rephrased_query=chat_message.rephrased_query,
error=chat_message.error,
citations=chat_message.citations,
reference_docs=chat_message.search_docs,
tool_call=chat_message.tool_call,
prompt_id=chat_message.prompt_id,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -312,25 +311,3 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()
def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
).scalar_one()
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise

View File

@@ -324,11 +324,8 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int] | None = None,
user_group_ids: list[int],
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@@ -405,11 +402,12 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@@ -209,7 +209,6 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
@@ -309,7 +308,7 @@ def get_access_info_for_documents(
return db_session.execute(stmt).all() # type: ignore
def upsert_documents(
def _upsert_documents(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
initial_boost: int = DEFAULT_BOOST,
@@ -365,24 +364,24 @@ def upsert_documents(
db_session.commit()
def upsert_document_by_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
def _upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
if not document_ids:
logger.info("`document_ids` is empty. Skipping.")
if not document_metadata_batch:
logger.info("`document_metadata_batch` is empty. Skipping.")
return
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
[
model_to_dict(
DocumentByConnectorCredentialPair(
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
id=document_metadata.document_id,
connector_id=document_metadata.connector_id,
credential_id=document_metadata.credential_id,
)
)
for doc_id in document_ids
for document_metadata in document_metadata_batch
]
)
# for now, there are no columns to update. If more metadata is added, then this
@@ -443,6 +442,17 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
db_session.commit()
def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
_upsert_documents(db_session, document_metadata_batch)
_upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
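
The upsert functions above note that this path is Postgres-specific because it relies on the ON CONFLICT clause. A standalone sketch of building such an insert with SQLAlchemy's postgresql dialect is shown below; the table definition is a simplified stand-in and nothing is executed against a database.

# Build a Postgres "INSERT ... ON CONFLICT DO NOTHING" statement with SQLAlchemy Core.
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
doc_by_cc_pair = Table(
    "document_by_connector_credential_pair",
    metadata,
    Column("id", String, primary_key=True),
    Column("connector_id", Integer, primary_key=True),
    Column("credential_id", Integer, primary_key=True),
)

stmt = insert(doc_by_cc_pair).values(
    [{"id": "doc-1", "connector_id": 1, "credential_id": 2}]
).on_conflict_do_nothing()

# Render the statement with the Postgres dialect to see the ON CONFLICT clause.
print(stmt.compile(dialect=postgresql.dialect()))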
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
@@ -551,7 +561,7 @@ def prepare_to_modify_documents(
db_session.commit() # ensure that we're not in a transaction
lock_acquired = False
for i in range(_NUM_LOCK_ATTEMPTS):
for _ in range(_NUM_LOCK_ATTEMPTS):
try:
with db_session.begin() as transaction:
lock_acquired = acquire_document_locks(
@@ -562,7 +572,7 @@ def prepare_to_modify_documents(
break
except OperationalError as e:
logger.warning(
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
f"Failed to acquire locks for documents, retrying. Error: {e}"
)
time.sleep(retry_delay)
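
The hunk above adjusts the lock-acquisition loop in prepare_to_modify_documents (one side logs the attempt index, the other does not). For reference, here is a standalone sketch of the bounded retry pattern itself, with made-up constants and a stub acquire function.

# Bounded retry: attempt to acquire a lock a fixed number of times, sleeping
# between failures, and report whether it ever succeeded.
import time

NUM_LOCK_ATTEMPTS = 5
RETRY_DELAY = 0.1


def acquire_with_retry(try_acquire) -> bool:
    for attempt in range(NUM_LOCK_ATTEMPTS):
        if try_acquire():
            return True
        print(f"Failed to acquire locks on attempt {attempt}, retrying")
        time.sleep(RETRY_DELAY)
    return False


# Example: succeed on the third attempt.
attempts = iter([False, False, True])
print(acquire_with_retry(lambda: next(attempts)))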

View File

@@ -5,7 +5,6 @@ class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
CANCELED = "canceled"
FAILED = "failed"
COMPLETED_WITH_ERRORS = "completed_with_errors"
@@ -13,17 +12,11 @@ class IndexingStatus(str, PyEnum):
terminal_states = {
IndexingStatus.SUCCESS,
IndexingStatus.COMPLETED_WITH_ERRORS,
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
}
return self in terminal_states
class IndexingMode(str, PyEnum):
UPDATE = "update"
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"

View File

@@ -67,13 +67,6 @@ def create_index_attempt(
return new_attempt.id
def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,
@@ -225,28 +218,6 @@ def mark_attempt_partially_succeeded(
raise
def mark_attempt_canceled(
index_attempt_id: int,
db_session: Session,
reason: str = "Unknown",
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_failed(
index_attempt_id: int,
db_session: Session,

View File

@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType, IndexingMode
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -53,11 +53,11 @@ from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.utils.special_types import JSON_ro
from danswer.key_value_store.interface import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.context.search.enums import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
@@ -126,7 +126,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -172,6 +171,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="creator",
@@ -351,11 +352,11 @@ class StandardAnswer__StandardAnswerCategory(Base):
)
class SlackChannelConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_channel_config__standard_answer_category"
class SlackBotConfig__StandardAnswerCategory(Base):
__tablename__ = "slack_bot_config__standard_answer_category"
slack_channel_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_channel_config.id"), primary_key=True
slack_bot_config_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot_config.id"), primary_key=True
)
standard_answer_category_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer_category.id"), primary_key=True
@@ -439,10 +440,6 @@ class ConnectorCredentialPair(Base):
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
Enum(IndexingMode, native_enum=False), nullable=True
)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
)
@@ -1186,7 +1183,7 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually display to users
# Models to actually disp;aly to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
@@ -1368,9 +1365,6 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
category_id: Mapped[int | None] = mapped_column(
ForeignKey("persona_category.id"), nullable=True
)
# Allows the Persona to specify a different LLM version than is controlled
# globally via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -1442,9 +1436,6 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
category: Mapped["PersonaCategory"] = relationship(
"PersonaCategory", back_populates="personas"
)
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
@@ -1457,17 +1448,6 @@ class Persona(Base):
)
class PersonaCategory(Base):
__tablename__ = "persona_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str | None] = mapped_column(String, nullable=True)
personas: Mapped[list["Persona"]] = relationship(
"Persona", back_populates="category"
)
AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)
@@ -1477,7 +1457,7 @@ class ChannelConfig(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
channel_name: str
channel_names: list[str]
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
@@ -1485,7 +1465,6 @@ class ChannelConfig(TypedDict):
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackBotResponseType(str, PyEnum):
@@ -1493,11 +1472,10 @@ class SlackBotResponseType(str, PyEnum):
CITATIONS = "citations"
class SlackChannelConfig(Base):
__tablename__ = "slack_channel_config"
class SlackBotConfig(Base):
__tablename__ = "slack_bot_config"
id: Mapped[int] = mapped_column(primary_key=True)
slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
@@ -1514,30 +1492,10 @@ class SlackChannelConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
slack_bot: Mapped["SlackBot"] = relationship(
"SlackBot",
back_populates="slack_channel_configs",
)
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
"StandardAnswerCategory",
secondary=SlackChannelConfig__StandardAnswerCategory.__table__,
back_populates="slack_channel_configs",
)
class SlackBot(Base):
__tablename__ = "slack_bot"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
back_populates="slack_bot",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
)
@@ -1776,9 +1734,9 @@ class StandardAnswerCategory(Base):
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_channel_configs: Mapped[list["SlackChannelConfig"]] = relationship(
"SlackChannelConfig",
secondary=SlackChannelConfig__StandardAnswerCategory.__table__,
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
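
The renamed association classes above (`SlackChannelConfig__StandardAnswerCategory` vs. `SlackBotConfig__StandardAnswerCategory`) are plain join tables driving `relationship(..., secondary=...)` links on both sides. Below is a minimal sketch of that many-to-many pattern in the same SQLAlchemy 2.0 style, with generic `Config` and `Category` models standing in for the real ones:

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Config__Category(Base):
    """Association table: one row per (config, category) pair."""

    __tablename__ = "config__category"

    config_id: Mapped[int] = mapped_column(ForeignKey("config.id"), primary_key=True)
    category_id: Mapped[int] = mapped_column(ForeignKey("category.id"), primary_key=True)


class Config(Base):
    __tablename__ = "config"

    id: Mapped[int] = mapped_column(primary_key=True)
    categories: Mapped[list["Category"]] = relationship(
        "Category",
        secondary=Config__Category.__table__,
        back_populates="configs",
    )


class Category(Base):
    __tablename__ = "category"

    id: Mapped[int] = mapped_column(primary_key=True)
    configs: Mapped[list[Config]] = relationship(
        "Config",
        secondary=Config__Category.__table__,
        back_populates="categories",
    )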

View File

@@ -20,20 +20,19 @@ from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
from danswer.db.models import PersonaCategory
from danswer.db.models import Prompt
from danswer.db.models import StarterMessage
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.utils.logger import setup_logger
@@ -113,31 +112,6 @@ def fetch_persona_by_id(
return persona
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id is not None:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id
# If the persona is not found, or the slack bot is using doc sets instead of personas,
# we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None
def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -185,7 +159,7 @@ def create_update_persona(
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.model_dump(exclude={"users", "groups"}),
**create_persona_request.dict(exclude={"users", "groups"}),
}
persona = upsert_persona(**persona_data)
@@ -284,6 +258,7 @@ def get_personas(
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
@@ -415,9 +390,6 @@ def upsert_prompt(
return prompt
# NOTE: This operation cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
def upsert_persona(
user: User | None,
name: str,
@@ -445,7 +417,6 @@ def upsert_persona(
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
category_id: int | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -486,7 +457,7 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
if persona.builtin_persona and not builtin_persona:
if not builtin_persona and persona.builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
@@ -502,6 +473,7 @@ def upsert_persona(
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.builtin_persona = builtin_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
@@ -511,9 +483,11 @@ def upsert_persona(
persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.category_id = category_id
persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -554,7 +528,6 @@ def upsert_persona(
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)
@@ -758,8 +731,6 @@ def get_prompt_by_name(
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
@@ -773,39 +744,3 @@ def delete_persona_by_name(
db_session.execute(stmt)
db_session.commit()
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
return db_session.query(PersonaCategory).all()
def create_assistant_category(
db_session: Session, name: str, description: str
) -> PersonaCategory:
category = PersonaCategory(name=name, description=description)
db_session.add(category)
db_session.commit()
return category
def update_persona_category(
category_id: int,
category_description: str,
category_name: str,
db_session: Session,
) -> None:
persona_category = (
db_session.query(PersonaCategory)
.filter(PersonaCategory.id == category_id)
.one_or_none()
)
if persona_category is None:
raise ValueError(f"Persona category with ID {category_id} does not exist")
persona_category.description = category_description
persona_category.name = category_name
db_session.commit()
def delete_persona_category(category_id: int, db_session: Session) -> None:
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
db_session.commit()
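
`upsert_persona` in the diff above follows a fetch-then-branch upsert: look the row up by id (or name), mutate it in place if found, otherwise construct and add a new one, then commit. Below is a stripped-down sketch of that shape with a hypothetical `Assistant` model and only two fields:

from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Assistant(Base):  # hypothetical stand-in for Persona
    __tablename__ = "assistant"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    description: Mapped[str] = mapped_column(String, default="")


def upsert_assistant(
    db_session: Session,
    name: str,
    description: str,
    assistant_id: int | None = None,
) -> Assistant:
    existing = None
    if assistant_id is not None:
        existing = db_session.scalar(
            select(Assistant).where(Assistant.id == assistant_id)
        )
    if existing is None:
        # Fall back to a name lookup, mirroring _get_persona_by_name above.
        existing = db_session.scalar(select(Assistant).where(Assistant.name == name))

    if existing is not None:
        # Update path: mutate the attached instance in place.
        existing.name = name
        existing.description = description
        assistant = existing
    else:
        # Insert path: construct a new row and add it to the session.
        assistant = Assistant(name=name, description=description)
        db_session.add(assistant)

    db_session.commit()
    return assistant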

View File

@@ -12,7 +12,6 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.context.search.models import SavedSearchSettings
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
@@ -22,6 +21,7 @@ from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
@@ -143,25 +143,6 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@@ -1,76 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import SlackBot
def insert_slack_bot(
db_session: Session,
name: str,
enabled: bool,
bot_token: str,
app_token: str,
) -> SlackBot:
slack_bot = SlackBot(
name=name,
enabled=enabled,
bot_token=bot_token,
app_token=app_token,
)
db_session.add(slack_bot)
db_session.commit()
return slack_bot
def update_slack_bot(
db_session: Session,
slack_bot_id: int,
name: str,
enabled: bool,
bot_token: str,
app_token: str,
) -> SlackBot:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if slack_bot is None:
raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}")
# update the app
slack_bot.name = name
slack_bot.enabled = enabled
slack_bot.bot_token = bot_token
slack_bot.app_token = app_token
db_session.commit()
return slack_bot
def fetch_slack_bot(
db_session: Session,
slack_bot_id: int,
) -> SlackBot:
slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id))
if slack_bot is None:
raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}")
return slack_bot
def remove_slack_bot(
db_session: Session,
slack_bot_id: int,
) -> None:
slack_bot = fetch_slack_bot(
db_session=db_session,
slack_bot_id=slack_bot_id,
)
db_session.delete(slack_bot)
db_session.commit()
def fetch_slack_bots(db_session: Session) -> Sequence[SlackBot]:
return db_session.scalars(select(SlackBot)).all()
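
The removed file above is a thin CRUD layer over the `SlackBot` model. Below is a hedged usage sketch, assuming these helpers live in a `danswer.db.slack_bot` module and that `get_sqlalchemy_engine` (imported elsewhere in this diff) provides the engine; the token strings are placeholders:

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.slack_bot import (  # assumed module path for the helpers above
    fetch_slack_bots,
    insert_slack_bot,
    remove_slack_bot,
    update_slack_bot,
)

with Session(get_sqlalchemy_engine()) as db_session:
    bot = insert_slack_bot(
        db_session=db_session,
        name="support-bot",
        enabled=True,
        bot_token="xoxb-placeholder",  # stored encrypted via EncryptedString in the model
        app_token="xapp-placeholder",
    )

    # Rename and disable the bot; update_slack_bot raises ValueError on a bad id.
    update_slack_bot(
        db_session=db_session,
        slack_bot_id=bot.id,
        name="support-bot-paused",
        enabled=False,
        bot_token=bot.bot_token,
        app_token=bot.app_token,
    )

    print([b.name for b in fetch_slack_bots(db_session)])

    remove_slack_bot(db_session=db_session, slack_bot_id=bot.id)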

View File

@@ -5,25 +5,25 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def _build_persona_name(channel_name: str) -> str:
return f"{SLACK_BOT_PERSONA_PREFIX}{channel_name}"
def _build_persona_name(channel_names: list[str]) -> str:
return f"{SLACK_BOT_PERSONA_PREFIX}{'-'.join(channel_names)}"
def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
@@ -38,9 +38,9 @@ def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
db_session.delete(rel)
def create_slack_channel_persona(
def create_slack_bot_persona(
db_session: Session,
channel_name: str,
channel_names: list[str],
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
@@ -48,11 +48,11 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""
# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
default_prompt = get_default_prompt(db_session)
persona = upsert_persona(
user=None, # Slack channel Personas are not attached to users
user=None, # Slack Bot Personas are not attached to users
persona_id=existing_persona_id,
name=persona_name,
description="",
@@ -78,15 +78,14 @@ def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
return []
def insert_slack_channel_config(
db_session: Session,
slack_bot_id: int,
def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
db_session: Session,
) -> SlackBotConfig:
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
@@ -111,37 +110,34 @@ def insert_slack_channel_config(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_channel_config = SlackChannelConfig(
slack_bot_id=slack_bot_id,
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
db_session.add(slack_channel_config)
db_session.add(slack_bot_config)
db_session.commit()
return slack_channel_config
return slack_bot_config
def update_slack_channel_config(
db_session: Session,
slack_channel_config_id: int,
def update_slack_bot_config(
slack_bot_config_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
slack_channel_config = db_session.scalar(
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
db_session: Session,
) -> SlackBotConfig:
slack_bot_config = db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
)
if slack_channel_config is None:
if slack_bot_config is None:
raise ValueError(
f"Unable to find Slack channel config with ID {slack_channel_config_id}"
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
versioned_fetch_standard_answer_categories_by_ids = (
@@ -163,25 +159,25 @@ def update_slack_channel_config(
)
# get the existing persona id before updating the object
existing_persona_id = slack_channel_config.persona_id
existing_persona_id = slack_bot_config.persona_id
# update the config
# NOTE: need to do this before cleaning up the old persona or else we
# will encounter `violates foreign key constraint` errors
slack_channel_config.persona_id = persona_id
slack_channel_config.channel_config = channel_config
slack_channel_config.response_type = response_type
slack_channel_config.standard_answer_categories = list(
slack_bot_config.persona_id = persona_id
slack_bot_config.channel_config = channel_config
slack_bot_config.response_type = response_type
slack_bot_config.standard_answer_categories = list(
existing_standard_answer_categories
)
slack_channel_config.enable_auto_filters = enable_auto_filters
slack_bot_config.enable_auto_filters = enable_auto_filters
# if the persona has changed, then clean up the old persona
if persona_id != existing_persona_id and existing_persona_id:
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
# if the existing persona was one created just for use with this Slack channel,
# if the existing persona was one created just for use with this Slack Bot,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
@@ -192,30 +188,28 @@ def update_slack_channel_config(
db_session.commit()
return slack_channel_config
return slack_bot_config
def remove_slack_channel_config(
db_session: Session,
slack_channel_config_id: int,
def remove_slack_bot_config(
slack_bot_config_id: int,
user: User | None,
db_session: Session,
) -> None:
slack_channel_config = db_session.scalar(
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
slack_bot_config = db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
)
if slack_channel_config is None:
if slack_bot_config is None:
raise ValueError(
f"Unable to find Slack channel config with ID {slack_channel_config_id}"
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_persona_id = slack_channel_config.persona_id
existing_persona_id = slack_bot_config.persona_id
if existing_persona_id:
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
# if the existing persona was one created just for use with this Slack channel,
# if the existing persona was one created just for use with this Slack Bot,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
@@ -227,28 +221,17 @@ def remove_slack_channel_config(
persona_id=existing_persona_id, user=user, db_session=db_session
)
db_session.delete(slack_channel_config)
db_session.delete(slack_bot_config)
db_session.commit()
def fetch_slack_channel_configs(
db_session: Session, slack_bot_id: int | None = None
) -> Sequence[SlackChannelConfig]:
if not slack_bot_id:
return db_session.scalars(select(SlackChannelConfig)).all()
return db_session.scalars(
select(SlackChannelConfig).where(
SlackChannelConfig.slack_bot_id == slack_bot_id
)
).all()
def fetch_slack_channel_config(
db_session: Session, slack_channel_config_id: int
) -> SlackChannelConfig | None:
def fetch_slack_bot_config(
db_session: Session, slack_bot_config_id: int
) -> SlackBotConfig | None:
return db_session.scalar(
select(SlackChannelConfig).where(
SlackChannelConfig.id == slack_channel_config_id
)
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
)
def fetch_slack_bot_configs(db_session: Session) -> Sequence[SlackBotConfig]:
return db_session.scalars(select(SlackBotConfig)).all()

View File

@@ -1,7 +1,6 @@
from collections.abc import Sequence
from uuid import UUID
from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
@@ -11,98 +10,30 @@ from danswer.auth.schemas import UserRole
from danswer.db.models import User
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
raise if:
- requested role is a curator
- requested role is a slack user
- requested role is an external permissioned user
- requested role is a limited user
- current role is a slack user
- current role is an external permissioned user
- current role is a limited user
"""
if current_role == UserRole.SLACK_USER:
raise HTTPException(
status_code=400,
detail="To change a Slack User's role, they must first login to Danswer via the web app.",
)
if current_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail="To change an External Permissioned User's role, they must first login to Danswer via the web app.",
)
if current_role == UserRole.LIMITED:
raise HTTPException(
status_code=400,
detail="To change a Limited User's role, they must first login to Danswer via the web app.",
)
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail="Curator role must be set via the User Group Menu",
)
if requested_role == UserRole.LIMITED:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to a Limited User role. "
"This role is automatically assigned to users through certain endpoints in the API."
),
)
if requested_role == UserRole.SLACK_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to a Slack User role. "
"This role is automatically assigned to users who only use Danswer via Slack."
),
)
if requested_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
raise HTTPException(
status_code=400,
detail=(
"A user cannot be set to an External Permissioned User role. "
"This role is automatically assigned to users who have been "
"pulled in to the system via an external permissions system."
),
)
def list_users(
db_session: Session, email_filter_string: str = "", include_external: bool = False
db_session: Session, email_filter_string: str = "", user: User | None = None
) -> Sequence[User]:
"""List all users. No pagination as of now, as the # of users
is assumed to be relatively small (<< 1 million)"""
stmt = select(User)
where_clause = []
if not include_external:
where_clause.append(User.role != UserRole.EXT_PERM_USER)
if email_filter_string:
where_clause.append(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
stmt = stmt.where(*where_clause)
stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
@@ -114,72 +45,68 @@ def get_user_by_email(email: str, db_session: Session) -> User | None:
def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return db_session.query(User).filter(User.id == user_id).first() # type: ignore
user = db_session.query(User).filter(User.id == user_id).first() # type: ignore
return user
def _generate_slack_user(email: str) -> User:
def _generate_non_web_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
return User(
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
has_web_login=False,
role=UserRole.BASIC,
)
def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
email = email.lower()
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
db_session.commit()
return user
user = _generate_slack_user(email=email)
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.commit()
return user
def _get_users_by_emails(
db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
# Extract found emails and convert to lowercase to avoid case sensitivity issues
found_users_emails = [user.email.lower() for user in found_users]
# Separate emails for users that were not found
missing_user_emails = [
email for email in lower_emails if email not in found_users_emails
]
return found_users, missing_user_emails
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.flush() # generate id
return user
def _generate_ext_permissioned_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
return User(
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
)
def batch_add_ext_perm_user_if_not_exists(
def batch_add_non_web_user_if_not_exists__no_commit(
db_session: Session, emails: list[str]
) -> list[User]:
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
db_session.add_all(new_users)
db_session.flush() # generate ids
return found_users + new_users
def batch_add_non_web_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
db_session.add_all(new_users)
db_session.commit()
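
The `_get_users_by_emails` / `batch_add_ext_perm_user_if_not_exists` pair above lowercases incoming emails and compares against `func.lower(User.email)` so the lookup is case-insensitive on both sides. Below is a small sketch of that matching step in isolation, using a hypothetical `Account` model with just an `email` column:

from sqlalchemy import String, func, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Account(Base):  # hypothetical stand-in for User
    __tablename__ = "account"

    id: Mapped[int] = mapped_column(primary_key=True)
    email: Mapped[str] = mapped_column(String, unique=True)


def split_known_and_missing_emails(
    db_session: Session, emails: list[str]
) -> tuple[list[Account], list[str]]:
    # Normalize once so "User@Example.com" and "user@example.com" collide.
    lower_emails = [email.lower() for email in emails]

    stmt = select(Account).where(func.lower(Account.email).in_(lower_emails))
    found = list(db_session.scalars(stmt).unique().all())

    # Lowercase the stored side as well before computing the missing set.
    found_lower = {account.email.lower() for account in found}
    missing = [email for email in lower_emails if email not in found_lower]
    return found, missing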

View File

@@ -3,10 +3,10 @@ import uuid
from sqlalchemy.orm import Session
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import IndexChunk
from danswer.search.models import InferenceChunk
DEFAULT_BATCH_SIZE = 30

View File

@@ -4,9 +4,9 @@ from datetime import datetime
from typing import Any
from danswer.access.models import DocumentAccess
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding

View File

@@ -15,7 +15,7 @@ schema DANSWER_CHUNK_NAME {
# Must have an additional field for whether to skip title embeddings
# This information cannot be extracted from either the title field nor title embedding
field skip_title type bool {
indexing: attribute
indexing: attribute
}
# May not always match the `semantic_identifier` e.g. for Slack docs the
# `semantic_identifier` will be the channel name, but the `title` will be empty
@@ -36,7 +36,7 @@ schema DANSWER_CHUNK_NAME {
}
# Title embedding (x1)
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
indexing: attribute | index
indexing: attribute
attribute {
distance-metric: angular
}
@@ -44,7 +44,7 @@ schema DANSWER_CHUNK_NAME {
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
indexing: attribute | index
indexing: attribute
attribute {
distance-metric: angular
}

View File

@@ -11,8 +11,6 @@ import httpx
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
@@ -46,6 +44,8 @@ from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -22,8 +22,6 @@ from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
@@ -70,6 +68,8 @@ from danswer.document_index.vespa_constants import VESPA_TIMEOUT
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -2,7 +2,6 @@ import concurrent.futures
import json
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
import httpx
from retry import retry
@@ -195,14 +194,6 @@ def _index_vespa_chunk(
logger.exception(
f"Failed to index document: '{document.id}'. Got response: '{res.text}'"
)
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage usually means "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
raise e
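
The block removed above special-cases HTTP 507 from the Vespa container. Below is a minimal sketch of the same check, assuming the response comes from `httpx` and `raise_for_status()`; the URL and payload handling around the real `_index_vespa_chunk` are not reproduced here:

from http import HTTPStatus

import httpx


def post_document(url: str, payload: dict) -> httpx.Response:
    try:
        res = httpx.post(url, json=payload, timeout=10.0)
        res.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
        return res
    except httpx.HTTPStatusError as e:
        if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
            # 507 from the index container usually means it is out of memory
            # or disk, not that the document itself is malformed.
            print(
                "HTTP 507 Insufficient Storage: allocate more memory or disk "
                "to the index container."
            )
        raise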

View File

@@ -3,7 +3,6 @@ from datetime import timedelta
from datetime import timezone
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.context.search.models import IndexFilters
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import CHUNK_ID
@@ -14,6 +13,7 @@ from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import METADATA_LIST
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.search.models import IndexFilters
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:
def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file, read_only=True)
workbook = openpyxl.load_workbook(file)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(
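
The one-line change above toggles `read_only=True` when loading workbooks. Below is a hedged sketch of an `xlsx_to_text`-style helper around it, showing how read-only mode streams rows; the `None`-cell filtering is an assumption, not copied from the original:

from typing import IO, Any

import openpyxl


def xlsx_to_text(file: IO[Any]) -> str:
    # read_only=True streams cells lazily, which keeps memory bounded for
    # large spreadsheets; values_only=True yields plain cell values.
    workbook = openpyxl.load_workbook(file, read_only=True)
    text_content = []
    for sheet in workbook.worksheets:
        sheet_string = "\n".join(
            ",".join(str(cell) for cell in row if cell is not None)
            for row in sheet.iter_rows(values_only=True)
        )
        text_content.append(sheet_string)
    return "\n\n".join(text_content)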

View File

@@ -59,12 +59,6 @@ class FileStore(ABC):
Contents of the file and metadata dict
"""
@abstractmethod
def read_file_record(self, file_name: str) -> PGFileStore:
"""
Read the file record by the name
"""
@abstractmethod
def delete_file(self, file_name: str) -> None:
"""

View File

@@ -10,11 +10,10 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_text
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
@@ -126,7 +125,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -135,7 +134,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.callback = callback
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -221,20 +220,9 @@ class Chunker:
mini_chunk_texts=self._get_mini_chunk_texts(text),
)
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
for section in document.sections:
section_text = section.text
section_link_text = section.link or ""
# If there is no useful content, not even the title, just drop it
if not section_text and (not document.title or section_idx > 0):
# If a section is empty and the document has no title, we can just drop it. We return a list of
# DocAwareChunks where each one contains the necessary information needed down the line for indexing.
# There is no concern about dropping whole documents from this list, it should not cause any indexing failures.
logger.warning(
f"Skipping section {section.text} from document "
f"{document.semantic_identifier} due to empty text after cleaning "
f" with link {section_link_text}"
)
continue
section_token_count = len(self.tokenizer.tokenize(section_text))
@@ -250,26 +238,31 @@ class Chunker:
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
if (
STRICT_CHUNK_TOKEN_LIMIT
and
# Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true
len(self.tokenizer.tokenize(split_text)) > content_token_limit
):
# If STRICT_CHUNK_TOKEN_LIMIT is true, manually check
# the token count of each split text to ensure it is
# not larger than the content_token_limit
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
split_token_count = len(self.tokenizer.tokenize(split_text))
if STRICT_CHUNK_TOKEN_LIMIT:
split_token_count = len(self.tokenizer.tokenize(split_text))
if split_token_count > content_token_limit:
# Further split the oversized chunk
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
text=small_chunk,
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
@@ -361,20 +354,11 @@ class Chunker:
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
"""
Takes in a list of documents and chunks them into smaller chunks for indexing
while persisting the document metadata.
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
final_chunks.extend(self._handle_single_document(document))
chunks = self._handle_single_document(document)
final_chunks.extend(chunks)
if self.callback:
self.callback.progress("Chunker.chunk", len(chunks))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks
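
On the newer side of the diff above, `Chunker.chunk` checks `callback.should_stop()` before each document and reports `callback.progress(...)` after it. Below is a minimal sketch of that cooperative-cancellation pattern, with a hypothetical `ProgressCallback` and a trivial stand-in for the chunking work:

from abc import ABC, abstractmethod


class ProgressCallback(ABC):  # hypothetical analogue of IndexingHeartbeatInterface
    @abstractmethod
    def should_stop(self) -> bool:
        """Signal to stop the looping function in flight."""

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Send progress updates to the caller."""


class PrintingCallback(ProgressCallback):
    def __init__(self, stop_after: int | None = None) -> None:
        self.seen = 0
        self.stop_after = stop_after

    def should_stop(self) -> bool:
        return self.stop_after is not None and self.seen >= self.stop_after

    def progress(self, tag: str, amount: int) -> None:
        self.seen += amount
        print(f"{tag}: processed {self.seen} items so far")


def chunk_documents(documents: list[str], callback: ProgressCallback | None = None) -> list[str]:
    chunks: list[str] = []
    for document in documents:
        if callback and callback.should_stop():
            # Mirrors Chunker.chunk: bail out cooperatively when asked.
            raise RuntimeError("chunk_documents: stop signal detected")
        new_chunks = document.split(". ")  # stand-in for real chunking
        chunks.extend(new_chunks)
        if callback:
            callback.progress("chunk_documents", len(new_chunks))
    return chunks


if __name__ == "__main__":
    chunk_documents(["First doc. Two sentences.", "Second doc."], PrintingCallback())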

View File

@@ -2,7 +2,7 @@ from abc import ABC
from abc import abstractmethod
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
callback: IndexingHeartbeatInterface | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -60,7 +60,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
callback=callback,
heartbeat=heartbeat,
)
@abstractmethod
@@ -83,7 +83,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -95,7 +95,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
callback,
heartbeat,
)
@log_function_time()
@@ -201,9 +201,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls,
search_settings: SearchSettings,
callback: IndexingHeartbeatInterface | None = None,
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -215,5 +213,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
callback=callback,
heartbeat=heartbeat,
)

View File

@@ -1,15 +1,41 @@
from abc import ABC
from abc import abstractmethod
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller."""
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")
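
`IndexingHeartbeat` above only does real work every `freq` calls, so a hot loop does not hit the database on every item. Below is the same throttling idea in isolation, with the commit replaced by a timestamp on a plain object (the real class looks up the `IndexAttempt` and commits, as shown above):

import time
from typing import Any


class ThrottledHeartbeat:
    """Perform an expensive side effect only every `freq` heartbeat calls."""

    def __init__(self, freq: int) -> None:
        self.cnt = 0
        self.freq = freq
        self.last_touched: float | None = None

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            # In the real IndexingHeartbeat this is where the IndexAttempt row's
            # time_updated is set to func.now() and the session is committed.
            self.last_touched = time.time()


hb = ThrottledHeartbeat(freq=100)
for _ in range(1_000):
    hb.heartbeat()
# Only 10 "touches" happened despite 1000 heartbeat calls.
assert hb.cnt == 1_000 and hb.last_touched is not None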

View File

@@ -1,9 +1,7 @@
import traceback
from functools import partial
from http import HTTPStatus
from typing import Protocol
import httpx
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
@@ -22,8 +20,7 @@ from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.document import upsert_documents
from danswer.db.document import upsert_documents_complete
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
@@ -34,7 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -65,7 +62,7 @@ def _upsert_documents_in_db(
db_session: Session,
) -> None:
# Metadata here refers to basic document info, not metadata about the actual content
document_metadata_list: list[DocumentMetadata] = []
doc_m_batch: list[DocumentMetadata] = []
for doc in documents:
first_link = next(
(section.link for section in doc.sections if section.link), ""
@@ -80,9 +77,12 @@ def _upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
)
document_metadata_list.append(db_doc_metadata)
doc_m_batch.append(db_doc_metadata)
upsert_documents(db_session, document_metadata_list)
upsert_documents_complete(
db_session=db_session,
document_metadata_batch=doc_m_batch,
)
# Insert document content metadata
for doc in documents:
@@ -95,25 +95,21 @@ def _upsert_documents_in_db(
document_id=doc.id,
db_session=db_session,
)
continue
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
else:
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
"""Figures out which documents actually need to be updated. If a document is already present
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.
NB: Still need to associate the document in the DB if multiple connectors are
indexing the same doc."""
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
id_update_time_map = {
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
}
@@ -156,14 +152,6 @@ def index_doc_batch_with_handler(
tenant_id=tenant_id,
)
except Exception as e:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
if INDEXING_EXCEPTION_LIMIT == 0:
raise
@@ -207,9 +195,9 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This precedes indexing it into the actual document index."""
documents: list[Document] = []
documents = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
@@ -224,58 +212,43 @@ def index_doc_batch_prepare(
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
elif (
document.title is not None and not document.title.strip() and empty_contents
):
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unusable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
else:
documents.append(document)
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
document_ids = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
else documents
)
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
logger.info(
f"Upserted {len(updatable_docs)} changed docs out of "
f"{len(documents)} total docs into the DB"
)
# for all docs, upsert the document to cc pair relationship
upsert_document_by_connector_credential_pair(
db_session,
index_attempt_metadata.connector_id,
index_attempt_metadata.credential_id,
document_ids,
)
# No docs to process because the batch is empty or every doc was already indexed
# No docs to update either because the batch is empty or every doc was already indexed
if not updatable_docs:
return None
# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
_upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
id_to_db_doc_map = {doc.id: doc for doc in db_docs}
return DocumentBatchPrepareContext(
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
@@ -296,10 +269,7 @@ def index_doc_batch(
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements
Returns a tuple where the first element is the number of new docs and the
second element is the number of chunks."""
memory requirements"""
no_access = DocumentAccess.build(
user_emails=[],
@@ -342,9 +312,9 @@ def index_doc_batch(
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for the immediate and most likely correct sync, but
# we still write data here for immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync via the celery queue
# always triggers a final metadata sync
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -381,8 +351,7 @@ def index_doc_batch(
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
# doc_updated_at is the connector source's idea of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
@@ -397,13 +366,10 @@ def index_doc_batch(
db_session.commit()
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
)
return result
def build_indexing_pipeline(
*,
@@ -414,7 +380,6 @@ def build_indexing_pipeline(
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -441,8 +406,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(
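
`get_doc_ids_to_update`, referenced above, shortcuts connector retries by dropping documents whose `updated_at` is not newer than what is already stored. Below is a small sketch of that filter with plain dataclasses standing in for the connector-side `Document` and the DB row; the exact comparison rules are an assumption:

from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class IncomingDoc:  # stand-in for the connector-side Document
    id: str
    doc_updated_at: datetime | None


@dataclass
class StoredDoc:  # stand-in for the DB-side Document row
    id: str
    doc_updated_at: datetime | None


def docs_to_update(incoming: list[IncomingDoc], stored: list[StoredDoc]) -> list[IncomingDoc]:
    # Map of id -> last stored update time, skipping rows with no timestamp.
    stored_times = {d.id: d.doc_updated_at for d in stored if d.doc_updated_at}

    updatable: list[IncomingDoc] = []
    for doc in incoming:
        stored_time = stored_times.get(doc.id)
        if stored_time and doc.doc_updated_at and doc.doc_updated_at <= stored_time:
            # Already indexed and not modified since: skip the expensive path.
            continue
        updatable.append(doc)
    return updatable


now = datetime.now(timezone.utc)
docs = [IncomingDoc("a", now), IncomingDoc("b", None)]
rows = [StoredDoc("a", now)]
assert [d.id for d in docs_to_update(docs, rows)] == ["b"]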

View File

@@ -1,6 +1,12 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import TypeAlias
from danswer.utils.special_types import JSON_ro
JSON_ro: TypeAlias = (
Mapping[str, "JSON_ro"] | Sequence["JSON_ro"] | str | int | float | bool | None
)
class KvKeyNotFoundError(Exception):
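
`JSON_ro`, defined inline above (and moved to `danswer.utils.special_types` on the other side of this diff), is a recursive read-only JSON type alias. A short sketch of declaring and exercising such an alias:

from collections.abc import Mapping, Sequence
from typing import TypeAlias

# Read-only JSON: nested mappings/sequences of primitives, checked recursively.
JSON_ro: TypeAlias = (
    Mapping[str, "JSON_ro"] | Sequence["JSON_ro"] | str | int | float | bool | None
)


def count_leaves(value: JSON_ro) -> int:
    if isinstance(value, Mapping):
        return sum(count_leaves(v) for v in value.values())
    if isinstance(value, Sequence) and not isinstance(value, str):
        return sum(count_leaves(v) for v in value)
    return 1


assert count_leaves({"a": [1, 2, {"b": None}], "c": "x"}) == 4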

View File

@@ -11,11 +11,11 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import is_valid_schema_name
from danswer.db.models import KVStore
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KeyValueStore
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -233,8 +233,6 @@ class Answer:
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
@@ -265,7 +263,6 @@ class Answer:
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
raw_user_text=self.question,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)

View File

@@ -58,7 +58,6 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
single_message_history: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -89,8 +88,6 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = raw_user_text
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None

Some files were not shown because too many files have changed in this diff.