Compare commits

..

62 Commits

Author SHA1 Message Date
pablodanswer
92be55c9d7 k 2024-12-01 17:54:22 -08:00
pablodanswer
dc8fa4c3cb update workflows 2024-12-01 17:50:52 -08:00
pablodanswer
c5aa64e3fb fix tests 2024-12-01 17:47:58 -08:00
pablodanswer
f4dea0821f k 2024-12-01 17:17:58 -08:00
pablodanswer
1ed4002902 nits 2024-12-01 17:11:37 -08:00
pablodanswer
952893d7f0 tests fixed 2024-12-01 17:10:05 -08:00
Yuhong Sun
3432d932d1 Citation code comments 2024-12-01 14:10:11 -08:00
Yuhong Sun
9bd0cb9eb5 Fix Citation Minor Bugs (#3294) 2024-12-01 13:55:24 -08:00
Chris Weaver
f12eb4a5cf Fix assistant prompt zero-ing (#3293) 2024-11-30 04:45:40 +00:00
Chris Weaver
16863de0aa Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Ugprade OpenAI as well

* Fix mypy
2024-11-30 04:42:56 +00:00
Weves
63d1eefee5 Add read_only=True for xlsx parsing 2024-11-28 16:02:02 -08:00
pablodanswer
e338677896 order seeding 2024-11-28 15:41:10 -08:00
hagen-danswer
7be80c4af9 increased the pagination limit for confluence spaces (#3288) 2024-11-28 19:04:38 +00:00
rkuo-danswer
7f1e4a02bf Feature/kill indexing (#3213)
* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-28 05:32:45 +00:00
rkuo-danswer
5be7d27285 use indexing flag in db for manually triggering indexing (#3264)
* use indexing flag in db for manually trigger indexing

* add comment.

* only try to release the lock if we actually succeeded with the lock

* ensure we don't trigger manual indexing on anything but the primary search settings

* comment usage of primary search settings

* run check for indexing immediately after indexing triggers are set

* reorder fix
2024-11-28 01:34:34 +00:00
Weves
fd84b7a768 Remove duplicate API key router 2024-11-27 16:30:59 -08:00
Subash-Mohan
36941ae663 fix: Cannot configure API keys #3191 2024-11-27 16:25:00 -08:00
Matthew Holland
212353ed4a Fixed default feedback options 2024-11-27 16:23:52 -08:00
Richard Kuo (Danswer)
eb8708f770 the word "error" might be throwing off sentry 2024-11-27 14:31:21 -08:00
Chris Weaver
ac448956e9 Add handling for rate limiting (#3280) 2024-11-27 14:22:15 -08:00
pablodanswer
634a0b9398 no stack by default (#3278) 2024-11-27 20:58:21 +00:00
hagen-danswer
09d3e47c03 Perm sync behavior change (#3262)
* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
2024-11-27 20:04:15 +00:00
pablodanswer
9c0cc94f15 refresh router -> refresh assistants (#3271) 2024-11-27 19:11:58 +00:00
hagen-danswer
07dfde2209 add continue in danswer button to slack bot responses (#3239)
* all done except routing

* fixed initial changes

* added backend endpoint for duplicating a chat session from Slack

* got chat duplication routing done

* got login routing working

* improved answer handling

* finished all checks

* finished all!

* made sure it works with google oauth

* dont remove that lol

* fixed weird thing

* bad comments
2024-11-27 18:25:38 +00:00
pablodanswer
28e2b78b2e Fix search dropdown (#3269)
* validate dropdown

* validate

* update organization

* move to utils
2024-11-27 16:10:07 +00:00
Emerson Gomes
0553062ac6 Adds icons for Google Gemini models and custom model icons for L… (#3218)
* Add description for Google Gemini models and custom model icons for LiteLLM (OpenAI) proxied models

* Adds Vertex AI aliases for Claude

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-11-26 10:13:21 -08:00
hagen-danswer
284e375ba3 Merge pull request #3257 from danswer-ai/minor-perm-sync
Improved logging for confluence doc sync and robust user creation
2024-11-26 09:59:38 -08:00
hagen-danswer
1f2f7d0ac2 Improved logging for confluence doc sync and robust user creation 2024-11-26 08:51:15 -08:00
pablodanswer
2ecc28b57d remove unused stripe promise (#3248) 2024-11-26 01:50:39 +00:00
rkuo-danswer
77cf9b3539 improve messaging and UI around cleanup of leftover index attempts (#3247)
* improve messaging and UI around cleanup of leftover index attempts

* add tag on init
2024-11-25 22:27:14 +00:00
Weves
076ce2ebd0 Saml fix 2024-11-25 09:12:43 -08:00
pablodanswer
b625ee32a7 File handling cleanup (#3240)
* fix google sites connector

* minior cleanup

* rm comments
2024-11-25 04:06:47 +00:00
Richard Kuo (Danswer)
c32b93fcc3 increase indexing worker concurrency to 3 2024-11-24 18:11:58 -08:00
pablodanswer
1c8476072e Assistant cleanup (#3236)
* minor cleanup

* ensure users don't modify built-in attributes of assistants

* update sidebar

* k

* update update flow + assistant creation
2024-11-25 00:13:34 +00:00
Chris Weaver
7573416ca1 Fix API keys for MIT users (#3237) 2024-11-24 16:55:19 -08:00
Yuhong Sun
86d8666481 Add Test Case 2024-11-24 15:42:14 -08:00
Yuhong Sun
8abcde91d4 Fix Test (#3242) 2024-11-24 14:31:28 -08:00
Yuhong Sun
3466451d51 Fix Prompt for Non Function Calling LLMs (#3241) 2024-11-24 14:16:57 -08:00
Yuhong Sun
413891f143 Token Level Log (#3238) 2024-11-23 18:41:50 -08:00
Yuhong Sun
7a0a4d4b79 Remove Deprecated Endpoints (#3235) 2024-11-23 14:44:23 -08:00
Yuhong Sun
a3439605a5 Remove Dead Code (#3234) 2024-11-23 14:31:59 -08:00
pablodanswer
694e79f5e1 minor enforcement of CSV length for internal processing (#3109) 2024-11-23 21:05:30 +00:00
pablodanswer
5dfafc8612 minor calendar cleanup (#3219) 2024-11-23 21:01:05 +00:00
Yuhong Sun
62a4aa10db Refactor Search (#3233) 2024-11-23 13:42:54 -08:00
Yuhong Sun
a357cdc4c9 Remove Dead Code (#3232) 2024-11-23 13:21:27 -08:00
Yuhong Sun
84615abfdd Seeding (#3231) 2024-11-23 13:12:42 -08:00
pablodanswer
8ae6b1960b Bugfix/usage report (#3075)
* fix pagination

* update side

* fixed query history

* minor update

* minor update

* typing
2024-11-23 20:11:39 +00:00
James Jordan
d9b87bbbc2 Fixed 400 error when author of ticket is no longer an active user in a Zendesk account. (#3168) 2024-11-23 12:15:38 -08:00
Sanju Lokuhitige
a0065b01af Update CONTRIBUTING.md (#3112)
fix Formatting and Linting hyperlink
2024-11-23 12:13:23 -08:00
pablodanswer
c5306148a3 Ensure daterange not consistently re rendered (#3229)
* ensure daterange not consistently re rendered

* minor clean up
2024-11-23 19:35:00 +00:00
hagen-danswer
1e17934de4 Merge pull request #3214 from danswer-ai/fix-slack-ui
cleaned up new slack bot creation
2024-11-23 10:53:47 -08:00
pablodanswer
93add96ccc Various Nits (#3228) 2024-11-23 10:53:24 -08:00
rkuo-danswer
3a466a4b08 add minimal retries to confluence probe (#3222)
* add minimal retries to confluence probe

* name variable correctly
2024-11-23 17:11:15 +00:00
hagen-danswer
85cbd9caed Increased slim doc batch size for confluence connector (#3221) 2024-11-23 00:42:15 +00:00
pablodanswer
9dc23bf3e7 revert to previous doc select logic (#3217)
* revert to previous doc select logic

* k
2024-11-22 23:26:53 +00:00
hagen-danswer
e32809f7ca moved it outside 2024-11-22 14:59:58 -08:00
hagen-danswer
3e58f9f8ab fixed ugly stuff 2024-11-22 14:39:55 -08:00
pablodanswer
2381c8d498 Refresh all assistants on assistant refresh (#3216)
* k

* k
2024-11-22 22:38:23 +00:00
hagen-danswer
c6dadb24dc cleaned up new slack bot creation 2024-11-22 11:53:51 -08:00
hagen-danswer
5dc07d4178 Each section is now cleaned before being chunked (#3210)
* Each section is now cleaned before being chunked

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-11-22 19:06:19 +00:00
Chris Weaver
129c8f8faf Add start/end date ability for query history as CSV endpoint (#3211) 2024-11-22 18:29:13 +00:00
pablodanswer
67bfcabbc5 llm provider causing re render in effect (#3205)
* llm provider causing re render in effect

* clean

* unused

* k
2024-11-22 16:53:24 +00:00
220 changed files with 3156 additions and 1812 deletions

111
.github/workflows/multi-tenant-tests.yml vendored Normal file
View File

@@ -0,0 +1,111 @@
name: Run Multi-Tenant Integration Tests
on:
workflow_dispatch:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
multi-tenant-integration-tests:
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Pull Required Docker Images
run: |
docker pull danswer/danswer-backend:latest
docker tag danswer/danswer-backend:latest danswer/danswer-backend:test
docker pull danswer/danswer-model-server:latest
docker tag danswer/danswer-model-server:latest danswer/danswer-model-server:test
docker pull danswer/danswer-web-server:latest
docker tag danswer/danswer-web-server:latest danswer/danswer-web-server:test
docker pull danswer/control-tenants-service:latest
docker tag danswer/control-tenants-service:latest danswer/control-tenants-service:test
- name: Build Integration Test Docker Image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
- name: Start Docker Containers for Multi-Tenant Tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
INTEGRATION_TEST_MODE=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
CONTROL_TENANTS_SERVICE_IMAGE=danswer/control-tenants-service:test \
docker compose -f docker-compose.dev.yml -f docker-compose.multi-tenant.yml -p danswer-stack up -d
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running multi-tenant integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
danswer/danswer-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_multitenant_tests
- name: Check Multi-Tenant Test Results
run: |
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Stop Docker Containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload Logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -8,7 +8,7 @@ on:
pull_request:
branches:
- main
- 'release/**'
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -16,11 +16,12 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
runs-on:
[runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -36,9 +37,9 @@ jobs:
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
@@ -50,7 +51,7 @@ jobs:
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
@@ -75,7 +76,7 @@ jobs:
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
@@ -88,58 +89,7 @@ jobs:
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# Start containers for multi-tenant tests
- name: Start Docker containers for multi-tenant tests
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
danswer/danswer-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
id: run_multitenant_tests
- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Start Docker containers
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -153,12 +103,12 @@ jobs:
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
@@ -229,7 +179,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -0,0 +1,45 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -0,0 +1,35 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -0,0 +1,30 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@@ -80,8 +80,8 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
@@ -609,7 +609,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -618,13 +618,21 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
async def double_check_user(
@@ -910,8 +918,8 @@ def get_oauth_router(
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@@ -921,7 +929,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

View File

@@ -24,7 +24,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -165,13 +165,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect

View File

@@ -5,7 +5,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@@ -241,13 +242,17 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -24,6 +25,9 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -49,7 +53,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@@ -195,7 +204,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -228,9 +237,13 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -249,7 +262,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -260,6 +272,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@@ -25,11 +25,13 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
@@ -37,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -77,7 +80,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = ""
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
@@ -159,7 +162,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -172,6 +175,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
@@ -205,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -231,22 +229,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
reindex,
db_session,
r,
tenant_id,
@@ -256,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
@@ -281,7 +303,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -289,13 +310,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -304,6 +326,7 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -368,6 +391,11 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -495,8 +523,11 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -509,6 +540,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
@@ -537,8 +572,30 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(10)
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
break
# do nothing for ongoing jobs that haven't been stopped
if not job.done():

View File

@@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
@@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app: Celery = celery_app

View File

@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -87,6 +88,10 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@@ -208,9 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -304,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -335,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure

View File

@@ -7,10 +7,10 @@ from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.context.search.models import InferenceSection
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -6,10 +6,10 @@ from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SearchResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType

View File

@@ -23,6 +23,16 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.context.search.utils import drop_llm_indices
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -56,16 +66,6 @@ from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line

View File

@@ -1,115 +0,0 @@
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -234,7 +234,7 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
@@ -422,6 +422,9 @@ LOG_ALL_MODEL_INTERACTIONS = (
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (

View File

@@ -1,9 +1,9 @@
import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -17,9 +17,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
@@ -27,8 +24,6 @@ DOC_TIME_DECAY = float(
)
BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat

View File

@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -51,6 +51,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -263,6 +265,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
@@ -286,6 +289,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
doc_metadata_list.append(
SlimDocument(
@@ -297,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list

View File

@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):
@@ -294,14 +294,17 @@ def _validate_connector_configuration(
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_without_retries = Confluence(
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
spaces = confluence_client_without_retries.get_all_spaces(limit=1)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(

View File

@@ -102,13 +102,21 @@ def _get_tickets(
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
return None
try:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
except requests.exceptions.HTTPError:
# Handle any API errors gracefully
return None
def _article_to_document(

View File

@@ -8,13 +8,13 @@ from pydantic import field_validator
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import SearchType
from danswer.db.models import Persona
from danswer.db.models import SearchSettings
from danswer.indexing.models import BaseChunk
from danswer.indexing.models import IndexingSetting
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider

View File

@@ -7,6 +7,22 @@ from sqlalchemy.orm import Session
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
from danswer.context.search.postprocessing.postprocessing import search_postprocessing
from danswer.context.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.context.search.retrieval.search_runner import retrieve_chunks
from danswer.context.search.utils import inference_section_from_chunks
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
@@ -16,22 +32,6 @@ from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.search.utils import relevant_sections_to_indices
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -9,19 +9,19 @@ from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import ChunkMetric
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import MAX_METRICS_CONTENT
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import RerankingModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,8 +1,8 @@
from sqlalchemy.orm import Session
from danswer.access.access import get_acl_for_user
from danswer.context.search.models import IndexFilters
from danswer.db.models import User
from danswer.search.models import IndexFilters
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:

View File

@@ -9,21 +9,25 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import BaseFilters
from danswer.search.models import IndexFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger

View File

@@ -6,6 +6,16 @@ from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.context.search.models import ChunkMetric
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import MAX_METRICS_CONTENT
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
from danswer.context.search.utils import inference_section_from_chunks
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_multilingual_expansion
from danswer.document_index.interfaces import DocumentIndex
@@ -14,16 +24,6 @@ from danswer.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -1,9 +1,9 @@
from typing import cast
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.context.search.models import SavedSearchSettings
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from typing import TypeVar
from danswer.chat.models import SectionRelevancePiece
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SavedSearchDoc
from danswer.context.search.models import SavedSearchDocWithContent
from danswer.context.search.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SavedSearchDocWithContent
from danswer.search.models import SearchDoc
T = TypeVar(

View File

@@ -18,20 +18,30 @@ from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.search.models import SavedSearchDoc
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
@@ -101,12 +111,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
return chunks
def clean_markdown_link_text(text: str) -> str:
def _clean_markdown_link_text(text: str) -> str:
# Remove any newlines within the text
return text.replace("\n", " ").strip()
def build_qa_feedback_block(
def _build_qa_feedback_block(
message_id: int, feedback_reminder_id: str | None = None
) -> Block:
return ActionsBlock(
@@ -115,7 +125,6 @@ def build_qa_feedback_block(
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
text="👍 Helpful",
style="primary",
value=feedback_reminder_id,
),
ButtonElement(
@@ -155,7 +164,7 @@ def get_document_feedback_blocks() -> Block:
)
def build_doc_feedback_block(
def _build_doc_feedback_block(
message_id: int,
document_id: str,
document_rank: int,
@@ -182,7 +191,7 @@ def get_restate_blocks(
]
def build_documents_blocks(
def _build_documents_blocks(
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
@@ -223,7 +232,7 @@ def build_documents_blocks(
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = build_doc_feedback_block(
feedback = _build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
@@ -241,7 +250,7 @@ def build_documents_blocks(
return section_blocks
def build_sources_blocks(
def _build_sources_blocks(
cited_documents: list[tuple[int, SavedSearchDoc]],
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
@@ -286,7 +295,7 @@ def build_sources_blocks(
+ ([days_ago_str] if days_ago_str else [])
)
document_title = clean_markdown_link_text(doc_sem_id)
document_title = _clean_markdown_link_text(doc_sem_id)
img_link = source_to_github_img_link(d.source_type)
section_blocks.append(
@@ -317,7 +326,50 @@ def build_sources_blocks(
return section_blocks
def build_quotes_block(
def _priority_ordered_documents_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
if not priority_ordered_docs:
return []
document_blocks = _build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if document_blocks:
document_blocks = [DividerBlock()] + document_blocks
return document_blocks
def _build_citations_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = _build_sources_blocks(cited_documents=cited_docs)
return citations_block
def _build_quotes_block(
quotes: list[DanswerQuote],
) -> list[Block]:
quote_lines: list[str] = []
@@ -359,58 +411,70 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
def _build_qa_response_blocks(
answer: OneShotQAResponse,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
skip_ai_feedback: bool = False,
feedback_reminder_id: str | None = None,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
quotes = answer.quotes.quotes if answer.quotes else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if time_cutoff or favor_recent or source_filters:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
or retrieval_info.applied_source_filters
):
filter_text = "Filters: "
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
if retrieval_info.applied_source_filters:
sources_str = ", ".join(
[s.value for s in retrieval_info.applied_source_filters]
)
filter_text += f"`Sources in [{sources_str}]`"
if time_cutoff or favor_recent:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
):
filter_text += " and "
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
if retrieval_info.applied_time_cutoff is not None:
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
if favor_recent:
if time_cutoff is not None:
if retrieval_info.recency_bias_multiplier > 1:
if retrieval_info.applied_time_cutoff is not None:
filter_text += "+ "
filter_text += "`Prioritize Recently Updated Docs`"
filter_block = SectionBlock(text=f"_{filter_text}_")
if not answer:
if not formatted_answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = build_quotes_block(quotes)
quotes_blocks = _build_quotes_block(quotes)
# if no quotes OR `build_quotes_block()` did not give back any blocks
# if no quotes OR `_build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
@@ -425,20 +489,37 @@ def build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if message_id is not None and not skip_ai_feedback:
response_blocks.append(
build_qa_feedback_block(
message_id=message_id, feedback_reminder_id=feedback_reminder_id
)
)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
)
return ActionsBlock(
block_id=build_continue_in_web_ui_id(message_id),
elements=[
ButtonElement(
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
text="Continue Chat in Danswer!",
style="primary",
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
),
],
)
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
return ActionsBlock(
block_id=build_feedback_id(message_id) if message_id is not None else None,
elements=[
@@ -483,3 +564,77 @@ def build_follow_up_resolved_blocks(
]
)
return [text_block, button_block]
def build_slack_response_blocks(
tenant_id: str | None,
message_info: SlackMessageInfo,
answer: OneShotQAResponse,
persona: Persona | None,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
answer_blocks = _build_qa_response_blocks(
answer=answer,
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
)
web_follow_up_block = []
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
message_id=answer.chat_message_id,
feedback_reminder_id=feedback_reminder_id,
)
)
citations_blocks = []
document_blocks = []
if use_citations:
# if citations are enabled, only show cited documents
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
all_blocks = (
restate_question_block
+ answer_blocks
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block
)
return all_blocks

View File

@@ -2,6 +2,7 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"

View File

@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
@@ -267,7 +267,7 @@ def handle_followup_button(
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_user_ids_from_emails(
tag_ids, remaining = fetch_slack_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:

View File

@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
@@ -184,7 +184,7 @@ def handle_message(
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
send_to, missing_ids = fetch_slack_user_ids_from_emails(
respond_member_group_list, client
)

View File

@@ -7,7 +7,6 @@ from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
@@ -21,12 +20,11 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -48,10 +46,6 @@ from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import DanswerLoggingAdapter
@@ -411,62 +405,16 @@ def handle_regular_answer(
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
persona=persona,
channel_conf=channel_conf,
use_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,

View File

@@ -27,6 +27,7 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.context.search.retrieval.search_runner import download_nltk_data
from danswer.danswerbot.slack.config import get_slack_channel_config_for_bot_and_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
@@ -75,7 +76,6 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable

View File

@@ -3,9 +3,9 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
@@ -216,6 +216,13 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
unique_prefix = str(uuid.uuid4())[:10]
return unique_prefix + ID_SEPARATOR + str(message_id)
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
@@ -313,7 +320,7 @@ def get_channel_name_from_id(
raise e
def fetch_user_ids_from_emails(
def fetch_slack_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -522,7 +529,7 @@ class SlackRateLimiter:
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
self, client: WebClient, channel: str, position: int, thread_ts: str | None
) -> None:
respond_in_thread(
client=client,

View File

@@ -2,6 +2,7 @@ import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
@@ -45,14 +46,16 @@ def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]:
]
def fetch_user_for_api_key(hashed_api_key: str, db_session: Session) -> User | None:
api_key = db_session.scalar(
select(ApiKey).where(ApiKey.hashed_api_key == hashed_api_key)
async def fetch_user_for_api_key(
hashed_api_key: str, async_db_session: AsyncSession
) -> User | None:
"""NOTE: this is async, since it's used during auth
(which is necessarily async due to FastAPI Users)"""
return await async_db_session.scalar(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
)
if api_key is None:
return None
return db_session.scalar(select(User).where(User.id == api_key.user_id)) # type: ignore
def get_api_key_fake_email(

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -18,6 +19,9 @@ from danswer.auth.schemas import UserRole
from danswer.chat.models import DocumentRelevance
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SavedSearchDoc
from danswer.context.search.models import SearchDoc as ServerSearchDoc
from danswer.db.models import ChatMessage
from danswer.db.models import ChatMessage__SearchDoc
from danswer.db.models import ChatSession
@@ -27,13 +31,11 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
@@ -250,6 +252,50 @@ def create_chat_session(
return chat_session
def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
"""
This takes a chat session id for a session in Slack and:
- Creates a new chat session in the DB
- Tries to copy the persona from the original chat session
(if it is available to the user clicking the button)
- Sets the user to the given user (if provided)
"""
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)
return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# Set this to empty string so the frontend will force a rename
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)
def update_chat_session(
db_session: Session,
user_id: UUID | None,
@@ -336,6 +382,28 @@ def get_chat_message(
return chat_message
def get_chat_session_by_message_id(
db_session: Session,
message_id: int,
) -> ChatSession:
"""
Should only be used for Slack
Get the chat session associated with a specific message ID
Note: this ignores permission checks.
"""
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if chat_message is None:
raise ValueError(
f"Unable to find chat session associated with message ID: {message_id}"
)
return chat_message.chat_session
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
user_id: UUID | None,
@@ -355,6 +423,44 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = get_or_create_root_message(
chat_session_id=new_chat_session_id,
db_session=db_session,
)
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
if chat_message.message_type == MessageType.SYSTEM:
continue
# Duplicate the message
new_root_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=chat_message.message,
files=chat_message.files,
rephrased_query=chat_message.rephrased_query,
error=chat_message.error,
citations=chat_message.citations,
reference_docs=chat_message.search_docs,
tool_call=chat_message.tool_call,
prompt_id=chat_message.prompt_id,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()
def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
).scalar_one()
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise

View File

@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int],
user_group_ids: list[int] | None = None,
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@@ -402,12 +405,11 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@@ -5,6 +5,7 @@ class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
CANCELED = "canceled"
FAILED = "failed"
COMPLETED_WITH_ERRORS = "completed_with_errors"
@@ -12,11 +13,17 @@ class IndexingStatus(str, PyEnum):
terminal_states = {
IndexingStatus.SUCCESS,
IndexingStatus.COMPLETED_WITH_ERRORS,
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
}
return self in terminal_states
class IndexingMode(str, PyEnum):
UPDATE = "update"
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"

View File

@@ -225,6 +225,28 @@ def mark_attempt_partially_succeeded(
raise
def mark_attempt_canceled(
index_attempt_id: int,
db_session: Session,
reason: str = "Unknown",
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_failed(
index_attempt_id: int,
db_session: Session,

View File

@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.db.enums import AccessType, IndexingMode
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -57,7 +57,7 @@ from danswer.utils.special_types import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.context.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
@@ -438,6 +438,10 @@ class ConnectorCredentialPair(Base):
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
Enum(IndexingMode, native_enum=False), nullable=True
)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
)
@@ -1480,6 +1484,7 @@ class ChannelConfig(TypedDict):
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackBotResponseType(str, PyEnum):

View File

@@ -20,6 +20,7 @@ from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
@@ -33,7 +34,6 @@ from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.utils.logger import setup_logger
@@ -113,6 +113,31 @@ def fetch_persona_by_id(
return persona
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id is not None:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id
# If the persona is not found, or the slack bot is using doc sets instead of personas,
# we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None
def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -160,7 +185,7 @@ def create_update_persona(
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.dict(exclude={"users", "groups"}),
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
persona = upsert_persona(**persona_data)
@@ -390,6 +415,9 @@ def upsert_prompt(
return prompt
# NOTE: This operation cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
def upsert_persona(
user: User | None,
name: str,
@@ -458,7 +486,7 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
if not builtin_persona and persona.builtin_persona:
if persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
@@ -474,7 +502,6 @@ def upsert_persona(
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.builtin_persona = builtin_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
@@ -484,10 +511,8 @@ def upsert_persona(
persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
@@ -733,6 +758,8 @@ def get_prompt_by_name(
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result

View File

@@ -12,6 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.context.search.models import SavedSearchSettings
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
@@ -21,7 +22,6 @@ from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
@@ -15,7 +16,6 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,

View File

@@ -103,17 +103,6 @@ def list_users(
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
@@ -128,7 +117,7 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return db_session.query(User).filter(User.id == user_id).first() # type: ignore
def _generate_non_web_slack_user(email: str) -> User:
def _generate_slack_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
@@ -149,13 +138,29 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
db_session.commit()
return user
user = _generate_non_web_slack_user(email=email)
user = _generate_slack_user(email=email)
db_session.add(user)
db_session.commit()
return user
def _generate_non_web_permissioned_user(email: str) -> User:
def _get_users_by_emails(
db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
# Extract found emails and convert to lowercase to avoid case sensitivity issues
found_users_emails = [user.email.lower() for user in found_users]
# Separate emails for users that were not found
missing_user_emails = [
email for email in lower_emails if email not in found_users_emails
]
return found_users, missing_user_emails
def _generate_ext_permissioned_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
@@ -169,12 +174,12 @@ def _generate_non_web_permissioned_user(email: str) -> User:
def batch_add_ext_perm_user_if_not_exists(
db_session: Session, emails: list[str]
) -> list[User]:
emails = [email.lower() for email in emails]
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
lower_emails = [email.lower() for email in emails]
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_permissioned_user(email=email))
for email in missing_lower_emails:
new_users.append(_generate_ext_permissioned_user(email=email))
db_session.add_all(new_users)
db_session.commit()

View File

@@ -3,10 +3,10 @@ import uuid
from sqlalchemy.orm import Session
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import IndexChunk
from danswer.search.models import InferenceChunk
DEFAULT_BATCH_SIZE = 30

View File

@@ -4,9 +4,9 @@ from datetime import datetime
from typing import Any
from danswer.access.models import DocumentAccess
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding

View File

@@ -11,6 +11,8 @@ import httpx
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
@@ -44,8 +46,6 @@ from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -22,6 +22,8 @@ from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
from danswer.configs.chat_configs import VESPA_SEARCHER_THREADS
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
@@ -68,8 +70,6 @@ from danswer.document_index.vespa_constants import VESPA_TIMEOUT
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.key_value_store.factory import get_kv_store
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

View File

@@ -3,6 +3,7 @@ from datetime import timedelta
from datetime import timezone
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.context.search.models import IndexFilters
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import CHUNK_ID
@@ -13,7 +14,6 @@ from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import METADATA_LIST
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.search.models import IndexFilters
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:
def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file)
workbook = openpyxl.load_workbook(file, read_only=True)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(

View File

@@ -14,6 +14,7 @@ from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_text
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
@@ -220,9 +221,20 @@ class Chunker:
mini_chunk_texts=self._get_mini_chunk_texts(text),
)
for section in document.sections:
section_text = section.text
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
# If there is no useful content, not even the title, just drop it
if not section_text and (not document.title or section_idx > 0):
# If a section is empty and the document has no title, we can just drop it. We return a list of
# DocAwareChunks where each one contains the necessary information needed down the line for indexing.
# There is no concern about dropping whole documents from this list, it should not cause any indexing failures.
logger.warning(
f"Skipping section {section.text} from document "
f"{document.semantic_identifier} due to empty text after cleaning "
f" with link {section_link_text}"
)
continue
section_token_count = len(self.tokenizer.tokenize(section_text))
@@ -238,31 +250,26 @@ class Chunker:
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
split_token_count = len(self.tokenizer.tokenize(split_text))
if STRICT_CHUNK_TOKEN_LIMIT:
split_token_count = len(self.tokenizer.tokenize(split_text))
if split_token_count > content_token_limit:
# Further split the oversized chunk
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
if (
STRICT_CHUNK_TOKEN_LIMIT
and
# Tokenizer only runs if STRICT_CHUNK_TOKEN_LIMIT is true
len(self.tokenizer.tokenize(split_text)) > content_token_limit
):
# If STRICT_CHUNK_TOKEN_LIMIT is true, manually check
# the token count of each split text to ensure it is
# not larger than the content_token_limit
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=split_text,
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
@@ -354,6 +361,10 @@ class Chunker:
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
"""
Takes in a list of documents and chunks them into smaller chunks for indexing
while persisting the document metadata.
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:
if self.callback:

View File

@@ -233,6 +233,8 @@ class Answer:
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(

View File

@@ -58,8 +58,8 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
single_message_history: str | None = None,
raw_user_text: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -89,11 +89,7 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = (
HumanMessage(content=raw_user_text)
if raw_user_text is not None
else user_message
)
self.raw_user_message = raw_user_text
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:

View File

@@ -3,6 +3,7 @@ from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
@@ -29,7 +30,6 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -2,45 +2,15 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
or distilled models. It only uses one context document and has very weak requirements of
output format.
"""
context_block = ""
if context_docs:
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=prompt.system_prompt,
context_block=context_block,
task_prompt=prompt.task_prompt,
user_query=question,
)
if prompt.datetime_aware:
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
return HumanMessage(content=prompt_str)
def _build_strong_llm_quotes_prompt(
@@ -81,15 +51,9 @@ def build_quotes_user_message(
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
return _build_strong_llm_quotes_prompt(
question=query,
context_docs=context_docs,
history_str=history_str,

View File

@@ -10,6 +10,8 @@ from danswer.chat.models import (
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
@@ -17,8 +19,6 @@ from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -13,6 +13,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AnswerResponseHandler(abc.ABC):
@@ -48,6 +51,9 @@ class CitationResponseHandler(AnswerResponseHandler):
self.processed_text = ""
self.citations: list[CitationInfo] = []
# TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
def handle_response_part(
self,
response_item: BaseMessage | None,

View File

@@ -67,9 +67,9 @@ class CitationProcessor:
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
@@ -77,13 +77,15 @@ class CitationProcessor:
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
result = "" # Initialize result here
result = ""
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
@@ -131,14 +133,6 @@ class CitationProcessor:
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
@@ -149,6 +143,7 @@ class CitationProcessor:
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (

View File

@@ -12,9 +12,9 @@ from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.context.search.models import InferenceChunk
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from danswer.chat.models import LlmDoc
from danswer.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunk
class DocumentIdOrderMapping(BaseModel):

View File

@@ -62,7 +62,7 @@ class ToolResponseHandler:
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
@@ -95,7 +95,7 @@ class ToolResponseHandler:
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
llm=llm,
)
if available_tools_and_args

View File

@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -236,6 +240,7 @@ class DefaultMultiLLM(LLM):
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
@@ -268,7 +273,7 @@ class DefaultMultiLLM(LLM):
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs: dict[str, Any] = {}
model_kwargs = model_kwargs or {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

View File

@@ -1,5 +1,8 @@
from typing import Any
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@ from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.long_term_log import LongTermLogger
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
long_term_logger=long_term_logger,
)

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
from danswer.utils.logger import setup_logger
@@ -117,10 +118,19 @@ class LLM(abc.ABC):
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
tokens = []
for message in messages:
if LOG_INDIVIDUAL_MODEL_TOKENS:
tokens.append(message.content)
yield message
if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
logger.debug(f"Model Tokens: {tokens}")
@abc.abstractmethod
def _stream_implementation(
self,

View File

@@ -1,3 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
@@ -136,9 +137,11 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string()
csv_preview = df.head().to_string(max_cols=max_columns)
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
@@ -383,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
return error_msg
def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# to specify their desired max context window, and it's
# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }
return starting_map
def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name
def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]
# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -395,22 +454,22 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
@@ -486,7 +545,7 @@ def get_max_input_tokens(
# `model_cost` dict is a named public interface:
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
# model_map is litellm.model_cost
litellm_model_map = litellm.model_cost
litellm_model_map = get_model_map()
input_toks = (
get_llm_max_tokens(

View File

@@ -26,6 +26,7 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import create_danswer_oauth_router
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -44,6 +45,7 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.server.api_key.api import router as api_key_router
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -280,6 +282,7 @@ def get_application() -> FastAPI:
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -323,7 +326,7 @@ def get_application() -> FastAPI:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_oauth_router(
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,

View File

@@ -0,0 +1,4 @@
class ModelServerRateLimitError(Exception):
"""
Exception raised for rate limiting errors from the model server.
"""

View File

@@ -1,4 +1,3 @@
import re
import threading
import time
from collections.abc import Callable
@@ -7,6 +6,9 @@ from typing import Any
import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -17,6 +19,9 @@ from danswer.configs.model_configs import (
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -50,28 +55,6 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
_INITIAL_FILTER = re.compile(
"["
"\U0000FFF0-\U0000FFFF" # Specials
"\U0001F000-\U0001F9FF" # Emoticons
"\U00002000-\U0000206F" # General Punctuation
"\U00002190-\U000021FF" # Arrows
"\U00002700-\U000027BF" # Dingbats
"]+",
flags=re.UNICODE,
)
def clean_openai_text(text: str) -> str:
# Remove specific Unicode ranges that might cause issues
cleaned = _INITIAL_FILTER.sub("", text)
# Remove any control characters except for newline and tab
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")
return cleaned
def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -122,28 +105,43 @@ class EmbeddingModel:
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
def _make_request() -> Response:
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# signify that this is a rate limit error
if response.status_code == 429:
raise ModelServerRateLimitError(response.text)
return EmbedResponse(**response.json())
response.raise_for_status()
return response
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
final_make_request_func = _make_request
# if the text type is a passage, add some default
# retries + handling for rate limiting
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
final_make_request_func = retry(
tries=3,
delay=5,
exceptions=(RequestException, ValueError, JSONDecodeError),
)(final_make_request_func)
# use 10 second delay as per Azure suggestion
final_make_request_func = retry(
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)
try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
def _batch_encode_texts(
self,
@@ -215,11 +213,6 @@ class EmbeddingModel:
for text in texts
]
if self.provider_type == EmbeddingProvider.OPENAI:
# If the provider is openai, we need to clean the text
# as a temporary workaround for the openai API
texts = [clean_openai_text(text) for text in texts]
batch_size = (
api_embedding_batch_size
if self.provider_type

View File

@@ -7,7 +7,7 @@ from transformers import logging as transformer_logging # type:ignore
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
@@ -131,7 +131,7 @@ def _try_initialize_tokenizer(
return tokenizer
except Exception as hf_error:
logger.warning(
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
)
# If both initializations fail, return None

View File

@@ -18,6 +18,11 @@ from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -42,11 +47,7 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.one_shot_answer.qa_utils import slackify_message_thread
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -194,13 +195,22 @@ def stream_answer_objects(
)
prompt = persona.prompts[0]
user_message_str = query_msg.message
# For this endpoint, we only save one user message to the chat session
# However, for slackbot, we want to include the history of the entire thread
if danswerbot_flow:
# Right now, we only support bringing over citations and search docs
# from the last message in the thread, not the entire thread
# in the future, we may want to retrieve the entire thread
user_message_str = slackify_message_thread(query_req.messages)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=query_msg.message,
token_count=len(llm_tokenizer.encode(query_msg.message)),
message=user_message_str,
token_count=len(llm_tokenizer.encode(user_message_str)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,

View File

@@ -9,12 +9,12 @@ from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
class QueryRephrase(BaseModel):
@@ -36,10 +36,6 @@ class PromptConfig(BaseModel):
datetime_aware: bool = True
class DocumentSetConfig(BaseModel):
id: int
class ToolConfig(BaseModel):
id: int

View File

@@ -51,3 +51,31 @@ def combine_message_thread(
total_token_count += message_token_count
return "\n\n".join(message_strs)
def slackify_message(message: ThreadMessage) -> str:
if message.role != MessageType.USER:
return message.message
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"DanswerBot said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)

View File

@@ -118,18 +118,6 @@ You should always get right to the point, and never use extraneous language.
"""
# For weak LLM which only takes one chunk and cannot output json
# Also not requiring quotes as it tends to not work
WEAK_LLM_PROMPT = f"""
{{system_prompt}}
{{context_block}}
{{task_prompt}}
{QUESTION_PAT.upper()}
{{user_query}}
""".strip()
# This is only for visualization for the users to specify their own prompts
# The actual flow does not work like this
PARAMATERIZED_PROMPT = f"""

View File

@@ -7,12 +7,12 @@ from langchain_core.messages import BaseMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.constants import DocumentSource
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Prompt
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger

View File

@@ -1,5 +1,8 @@
import time
import redis
from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -31,6 +34,44 @@ class RedisConnector:
self.tenant_id, self.id, search_settings_id, self.redis
)
def wait_for_indexing_termination(
self,
search_settings_list: list[SearchSettings],
timeout: float = 15.0,
) -> bool:
"""
Returns True if all indexing for the given redis connector is finished within the given timeout.
Returns False if the timeout is exceeded
This check does not guarantee that current indexings being terminated
won't get restarted midflight
"""
finished = False
start = time.monotonic()
while True:
still_indexing = False
for search_settings in search_settings_list:
redis_connector_index = self.new_index(search_settings.id)
if redis_connector_index.fenced:
still_indexing = True
break
if not still_indexing:
finished = True
break
now = time.monotonic()
if now - start > timeout:
break
time.sleep(1)
continue
return finished
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""

View File

@@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncData(BaseModel):
class RedisConnectorPermissionSyncPayload(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorPermissionSync:
@@ -78,14 +79,14 @@ class RedisConnectorPermissionSync:
return False
@property
def payload(self) -> RedisConnectorPermissionSyncData | None:
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPermissionSyncData.model_validate_json(
payload = RedisConnectorPermissionSyncPayload.model_validate_json(
cast(str, fence_str)
)
@@ -93,7 +94,7 @@ class RedisConnectorPermissionSync:
def set_fence(
self,
payload: RedisConnectorPermissionSyncData | None,
payload: RedisConnectorPermissionSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
@@ -162,6 +163,12 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"

View File

@@ -1,11 +1,18 @@
from datetime import datetime
from typing import cast
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSyncPayload(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
@@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:
return False
def set_fence(self, value: bool) -> None:
if not value:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
self.redis.set(self.fence_key, payload.model_dump_json())
@property
def generator_complete(self) -> int | None:

View File

@@ -29,6 +29,8 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
def __init__(
self,
tenant_id: str | None,
@@ -51,6 +53,7 @@ class RedisConnectorIndex:
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -92,6 +95,18 @@ class RedisConnectorIndex:
self.redis.set(self.fence_key, payload.model_dump_json())
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)

View File

@@ -1,12 +1,12 @@
import re
from danswer.chat.models import SectionRelevancePiece
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_SYSTEM_PROMPT
from danswer.prompts.agentic_evaluation import AGENTIC_SEARCH_USER_PROMPT
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -1,9 +1,9 @@
# NOTE No longer used. This needs to be revisited later.
import re
from collections.abc import Iterator
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -46,7 +46,7 @@ def extract_answerability_bool(model_raw: str) -> bool:
def get_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
user_query: str, skip_check: bool = False
) -> tuple[str, bool]:
if skip_check:
return "Query Answerability Evaluation feature is turned off", True
@@ -67,7 +67,7 @@ def get_query_answerability(
def stream_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
user_query: str, skip_check: bool = False
) -> Iterator[str]:
if skip_check:
yield get_json_line(

View File

@@ -33,6 +33,7 @@ from danswer.server.documents.models import ConnectorBase
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.configs.app_configs import INTEGRATION_TEST_MODE
logger = setup_logger()
@@ -127,6 +128,9 @@ def seed_initial_documents(
- Indexing the documents into Vespa
- Create a fake index attempt with fake times
"""
if INTEGRATION_TEST_MODE:
return
logger.info("Seeding initial documents")
kv_store = get_kv_store()

View File

@@ -5,6 +5,7 @@ from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
@@ -14,7 +15,6 @@ from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(
@@ -81,6 +81,7 @@ def load_personas_from_yaml(
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@@ -158,7 +161,19 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method may wait up to 30 seconds if pausing the connector due to the need to
terminate tasks in progress. Tasks are not guaranteed to terminate within the
timeout.
Returns HTTPStatus.OK if everything finished.
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
did not finish within the timeout.
"""
WAIT_TIMEOUT = 15.0
still_terminating = False
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -173,10 +188,76 @@ def update_cc_pair_status(
)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
redis_connector.stop.set_fence(True)
while True:
logger.debug(
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
"Wait for indexing soft termination timed out. "
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
if not redis_connector_index.fenced:
continue
index_payload = redis_connector_index.payload
if not index_payload:
continue
if not index_payload.celery_task_id:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
logger.debug(
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
)
still_terminating = True
break
finally:
redis_connector.stop.set_fence(False)
update_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -185,6 +266,18 @@ def update_cc_pair_status(
db_session.commit()
if still_terminating:
return JSONResponse(
status_code=HTTPStatus.ACCEPTED,
content={
"message": "Request accepted, background task termination still in progress"
},
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -267,9 +360,9 @@ def prune_cc_pair(
)
logger.info(
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"Pruning cc_pair: cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(

View File

@@ -17,9 +17,9 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.connectors.google_utils.google_auth import (
@@ -59,6 +59,7 @@ from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import fetch_connectors
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
@@ -74,6 +75,7 @@ from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import IndexingMode
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
@@ -86,7 +88,6 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -792,12 +793,10 @@ def connector_run_once(
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
) -> StatusResponse[int]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids
@@ -843,54 +842,41 @@ def connector_run_once(
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]
index_attempt_ids = []
num_triggers = 0
for cc_pair in connector_credential_pairs:
if cc_pair is not None:
attempt_id = try_creating_indexing_task(
primary_app,
cc_pair,
search_settings,
run_info.from_beginning,
db_session,
r,
tenant_id,
indexing_mode = IndexingMode.UPDATE
if run_info.from_beginning:
indexing_mode = IndexingMode.REINDEX
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
num_triggers += 1
logger.info(
f"connector_run_once - marking cc_pair with indexing trigger: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"indexing_trigger={indexing_mode}"
)
if attempt_id:
logger.info(
f"connector_run_once - try_creating_indexing_task succeeded: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"attempt={attempt_id} "
)
index_attempt_ids.append(attempt_id)
else:
logger.info(
f"connector_run_once - try_creating_indexing_task failed: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id}"
)
if not index_attempt_ids:
msg = "No new indexing attempts created, indexing jobs are queued or running."
logger.info(msg)
raise HTTPException(
status_code=400,
detail=msg,
)
# run the beat task to pick up the triggers immediately
primary_app.send_task(
"check_for_indexing",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}"
msg = f"Marked {num_triggers} index attempts with indexing triggers."
return StatusResponse(
success=True,
message=msg,
data=index_attempt_ids,
data=num_triggers,
)

View File

@@ -5,6 +5,10 @@ from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.context.search.models import IndexFilters
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
@@ -12,8 +16,6 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import IndexFilters
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
from danswer.server.documents.models import DocumentInfo

Some files were not shown because too many files have changed in this diff Show More